mz_transform/predicate_pushdown.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes predicates down through other operators.
11//!
12//! This action generally improves the quality of the query, in that selective per-record
13//! filters reduce the volume of data before they arrive at more expensive operators.
14//!
15//!
16//! The one time when this action might not improve the quality of a query is
17//! if a filter gets pushed down onto an arrangement, because that blocks arrangement
18//! reuse. It is assumed that actions that need an arrangement are responsible for
19//! lifting filters out of the way.
20//!
21//! Predicate pushdown will not push down literal errors, unless it is certain that
22//! the literal errors will be unconditionally evaluated. For example, the pushdown
23//! will not happen if not all predicates can be pushed down (e.g., for Reduce and Map),
24//! or if we are not certain that the input is non-empty (e.g., for Join).
25//! Note that this is not addressing the problem in its full generality, because this problem can
26//! occur with any function call that might error (although much more rarely than with literal
27//! errors). See <https://github.com/MaterializeInc/database-issues/issues/4972#issuecomment-1547391011>
28//!
29//! ```rust
30//! use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr};
31//! use mz_ore::id_gen::IdGen;
32//! use mz_repr::{ColumnType, Datum, RelationType, ScalarType};
33//! use mz_repr::optimize::OptimizerFeatures;
34//! use mz_transform::{typecheck, Transform, TransformCtx};
35//! use mz_transform::dataflow::DataflowMetainfo;
36//!
37//! use mz_transform::predicate_pushdown::PredicatePushdown;
38//!
39//! let input1 = MirRelationExpr::constant(vec![], RelationType::new(vec![
40//! ScalarType::Bool.nullable(false),
41//! ]));
42//! let input2 = MirRelationExpr::constant(vec![], RelationType::new(vec![
43//! ScalarType::Bool.nullable(false),
44//! ]));
45//! let input3 = MirRelationExpr::constant(vec![], RelationType::new(vec![
46//! ScalarType::Bool.nullable(false),
47//! ]));
48//! let join = MirRelationExpr::join(
49//! vec![input1.clone(), input2.clone(), input3.clone()],
50//! vec![vec![(0, 0), (2, 0)].into_iter().collect()],
51//! );
52//!
53//! let predicate0 = MirScalarExpr::column(0);
54//! let predicate1 = MirScalarExpr::column(1);
55//! let predicate01 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(2), BinaryFunc::AddInt64);
56//! let predicate012 = MirScalarExpr::literal_false();
57//!
58//! let mut expr = join.filter(
59//! vec![
60//! predicate0.clone(),
61//! predicate1.clone(),
62//! predicate01.clone(),
63//! predicate012.clone(),
64//! ]);
65//!
66//! let features = OptimizerFeatures::default();
67//! let typecheck_ctx = typecheck::empty_context();
68//! let mut df_meta = DataflowMetainfo::default();
69//! let mut transform_ctx = TransformCtx::local(&features, &typecheck_ctx, &mut df_meta, None);
70//!
71//! PredicatePushdown::default().transform(&mut expr, &mut transform_ctx);
72//!
73//! let predicate00 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(0), BinaryFunc::AddInt64);
74//! let expected_expr = MirRelationExpr::join(
75//! vec![
76//! input1.clone().filter(vec![predicate0.clone(), predicate00.clone()]),
77//! input2.clone().filter(vec![predicate0.clone()]),
78//! input3.clone().filter(vec![predicate0, predicate00])
79//! ],
80//! vec![vec![(0, 0), (2, 0)].into_iter().collect()],
81//! ).filter(vec![predicate012]);
82//! assert_eq!(expected_expr, expr)
83//! ```
84
85use std::collections::{BTreeMap, BTreeSet};
86
87use itertools::Itertools;
88use mz_expr::visit::{Visit, VisitChildren};
89use mz_expr::{
90 AggregateFunc, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
91 VariadicFunc, func,
92};
93use mz_ore::soft_assert_eq_no_log;
94use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
95use mz_repr::{ColumnType, Datum, ScalarType};
96
97use crate::{TransformCtx, TransformError};
98
99/// Pushes predicates down through other operators.
100#[derive(Debug)]
101pub struct PredicatePushdown {
102 recursion_guard: RecursionGuard,
103}
104
105impl Default for PredicatePushdown {
106 fn default() -> PredicatePushdown {
107 PredicatePushdown {
108 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
109 }
110 }
111}
112
113impl CheckedRecursion for PredicatePushdown {
114 fn recursion_guard(&self) -> &RecursionGuard {
115 &self.recursion_guard
116 }
117}
118
119impl crate::Transform for PredicatePushdown {
120 fn name(&self) -> &'static str {
121 "PredicatePushdown"
122 }
123
124 #[mz_ore::instrument(
125 target = "optimizer",
126 level = "debug",
127 fields(path.segment = "predicate_pushdown")
128 )]
129 fn actually_perform_transform(
130 &self,
131 relation: &mut MirRelationExpr,
132 _: &mut TransformCtx,
133 ) -> Result<(), TransformError> {
134 let mut empty = BTreeMap::new();
135 let result = self.action(relation, &mut empty);
136 mz_repr::explain::trace_plan(&*relation);
137 result
138 }
139}
140
141impl PredicatePushdown {
142 /// Predicate pushdown
143 ///
144 /// This method looks for opportunities to push predicates toward
145 /// sources of data. Primarily, this means taking the predicates of a
146 /// `Filter` expression and moving them through the operators beneath it.
147 ///
148 /// In addition, the method accumulates the intersection of predicates
149 /// applied to each `Get` expression, so that these predicates can
150 /// then be pushed into the corresponding `Let` binding, or to the external
151 /// source of the data if the `Get` binds to another view.
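///
/// For example (informal sketch, not a doctest): if every use of a `Let`
/// binding applies the same predicate, that predicate can be moved onto the
/// binding's value:
///
/// ```text
/// Let l = <value> in Union (Filter p (Get l)) (Filter p (Get l))
/// =>
/// Let l = Filter p <value> in Union (Get l) (Get l)
/// ```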
152 pub fn action(
153 &self,
154 relation: &mut MirRelationExpr,
155 get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
156 ) -> Result<(), TransformError> {
157 self.checked_recur(|_| {
158 // In the case of Filter or Get we have specific work to do;
159 // otherwise we should recursively descend.
160 match relation {
161 MirRelationExpr::Filter { input, predicates } => {
162 // Reduce the predicates to determine as best as possible
163 // whether they are literal errors before working with them.
164 let input_type = input.typ();
165 for predicate in predicates.iter_mut() {
166 predicate.reduce(&input_type.column_types);
167 }
168
169 // It can be helpful to know whether all predicates are literal errors;
170 // if some are not, that justifies not pushing down the literal errors.
171 let all_errors = predicates.iter().all(|p| p.is_literal_err());
172 // Depending on the type of `input` we have different
173 // logic to apply to consider pushing `predicates` down.
174 match &mut **input {
175 MirRelationExpr::Let { body, .. }
176 | MirRelationExpr::LetRec { body, .. } => {
177 // Push all predicates to the body.
178 **body = body
179 .take_dangerous()
180 .filter(std::mem::replace(predicates, Vec::new()));
181
182 self.action(input, get_predicates)?;
183 }
184 MirRelationExpr::Get { id, .. } => {
185 // We can report the predicates upward in `get_predicates`,
186 // but we are not yet able to delete them from the
187 // `Filter`.
188 get_predicates
189 .entry(*id)
190 .or_insert_with(|| predicates.iter().cloned().collect())
191 .retain(|p| predicates.contains(p));
192 }
193 MirRelationExpr::Join {
194 inputs,
195 equivalences,
196 ..
197 } => {
198 // We want to scan `predicates` for any that can
199 // 1) become join variable constraints
200 // 2) apply to individual elements of `inputs`.
201 // Figuring out the set of predicates that belong to
202 // the latter group requires 1) knowing which predicates
203 // are in the former group and 2) that the variable
204 // constraints be in canonical form.
205 // Thus, there is a first scan across `predicates` to
206 // populate the join variable constraints
207 // and a second scan across the remaining predicates
208 // to see which ones can become individual elements of
209 // `inputs`.
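// For example (informal sketch): with one-column inputs A and B, the
// predicate `#0 = #1` becomes the equivalence class `{#0, #1}` (plus a
// `!isnull(#0)` predicate if `#0` is nullable), while `#0 = 7` refers to
// a single input, so it is left for the second scan, which considers
// pushing it down to A.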
210
211 let input_mapper = mz_expr::JoinInputMapper::new(inputs);
212
213 // Predicates not translated into join variable
214 // constraints. We will attempt to push them to the
215 // individual inputs, and retain those that cannot be pushed.
216 let mut pred_not_translated = Vec::new();
217
218 for mut predicate in predicates.drain(..) {
219 use mz_expr::{BinaryFunc, UnaryFunc};
220 if let MirScalarExpr::CallBinary {
221 func: BinaryFunc::Eq,
222 expr1,
223 expr2,
224 } = &predicate
225 {
226 // Translate into join variable constraints:
227 // 1) `nonliteral1 == nonliteral2` constraints
228 // 2) `expr == literal` where `expr` refers to more
229 // than one input.
230 let input_count =
231 input_mapper.lookup_inputs(&predicate).count();
232 if (!expr1.is_literal() && !expr2.is_literal())
233 || input_count >= 2
234 {
235 // `col1 == col2` as a `MirScalarExpr`
236 // implies `!isnull(col1)` as well.
237 // `col1 == col2` as a join constraint does
238 // not have this extra implication.
239 // Thus, when translating the
240 // `MirScalarExpr` to a join constraint, we
241 // need to retain the `!isnull(col1)`
242 // information.
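// For example, for a nullable `col1`, the filter `col1 = col2` never
// passes rows where `col1` is NULL, while the equivalence `{col1, col2}`
// alone does not guarantee this, so we keep `!isnull(col1)` as an
// ordinary predicate.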
243 if expr1.typ(&input_type.column_types).nullable {
244 pred_not_translated.push(
245 expr1
246 .clone()
247 .call_unary(UnaryFunc::IsNull(func::IsNull))
248 .call_unary(UnaryFunc::Not(func::Not)),
249 );
250 } else if expr2.typ(&input_type.column_types).nullable {
251 pred_not_translated.push(
252 expr2
253 .clone()
254 .call_unary(UnaryFunc::IsNull(func::IsNull))
255 .call_unary(UnaryFunc::Not(func::Not)),
256 );
257 }
258 equivalences
259 .push(vec![(**expr1).clone(), (**expr2).clone()]);
260 continue;
261 }
262 } else if let Some((expr1, expr2)) =
263 Self::extract_equal_or_both_null(
264 &mut predicate,
265 &input_type.column_types,
266 )
267 {
268 // Also translate into join variable constraints:
269 // 3) `((nonliteral1 = nonliteral2) || (nonliteral1
270 // is null && nonliteral2 is null))`
271 equivalences.push(vec![expr1, expr2]);
272 continue;
273 }
274 pred_not_translated.push(predicate)
275 }
276
277 mz_expr::canonicalize::canonicalize_equivalences(
278 equivalences,
279 std::iter::once(&input_type.column_types),
280 );
281
282 let (retain, push_downs) = Self::push_filters_through_join(
283 &input_mapper,
284 equivalences,
285 pred_not_translated,
286 );
287
288 Self::update_join_inputs_with_push_downs(inputs, push_downs);
289
290 // Recursively descend on the join
291 self.action(input, get_predicates)?;
292
293 // remove all predicates that were pushed down from the current Filter node
294 *predicates = retain;
295 }
296 MirRelationExpr::Reduce {
297 input: inner,
298 group_key,
299 aggregates,
300 monotonic: _,
301 expected_group_size: _,
302 } => {
303 let mut retain = Vec::new();
304 let mut push_down = Vec::new();
305 for predicate in predicates.drain(..) {
306 // Do not push down literal errors unless all predicates are literal errors.
307 if !predicate.is_literal_err() || all_errors {
308 let mut supported = true;
309 let mut new_predicate = predicate.clone();
310 new_predicate.visit_pre(|e| {
311 if let MirScalarExpr::Column(c) = e {
312 if *c >= group_key.len() {
313 supported = false;
314 }
315 }
316 });
317 if supported {
318 new_predicate.visit_mut_post(&mut |e| {
319 if let MirScalarExpr::Column(i) = e {
320 *e = group_key[*i].clone();
321 }
322 })?;
323 push_down.push(new_predicate);
324 } else if let MirScalarExpr::Column(col) = &predicate {
325 if *col == group_key.len()
326 && aggregates.len() == 1
327 && aggregates[0].func == AggregateFunc::Any
328 {
329 push_down.push(aggregates[0].expr.clone());
330 aggregates[0].expr = MirScalarExpr::literal_ok(
331 Datum::True,
332 ScalarType::Bool,
333 );
334 } else {
335 retain.push(predicate);
336 }
337 } else {
338 retain.push(predicate);
339 }
340 } else {
341 retain.push(predicate);
342 }
343 }
344
345 if !push_down.is_empty() {
346 *inner = Box::new(inner.take_dangerous().filter(push_down));
347 }
348 self.action(inner, get_predicates)?;
349
350 // remove all predicates that were pushed down from the current Filter node
351 std::mem::swap(&mut retain, predicates);
352 }
353 MirRelationExpr::TopK {
354 input,
355 group_key,
356 order_key: _,
357 limit,
358 offset: _,
359 monotonic: _,
360 expected_group_size: _,
361 } => {
362 let mut retain = Vec::new();
363 let mut push_down = Vec::new();
364
365 let mut support = BTreeSet::new();
366 support.extend(group_key.iter().cloned());
367 if let Some(limit) = limit {
368 // Strictly speaking not needed because the
369 // `limit` support should be a subset of the
370 // `group_key` support, but we don't want to
371 // take this for granted here.
372 limit.support_into(&mut support);
373 }
374
375 for predicate in predicates.drain(..) {
376 // Do not push down literal errors unless all predicates are literal errors.
377 if (!predicate.is_literal_err() || all_errors)
378 && predicate.support().is_subset(&support)
379 {
380 push_down.push(predicate);
381 } else {
382 retain.push(predicate);
383 }
384 }
385
386 // remove all predicates that were pushed down from the current Filter node
387 std::mem::swap(&mut retain, predicates);
388
389 if !push_down.is_empty() {
390 *input = Box::new(input.take_dangerous().filter(push_down));
391 }
392
393 self.action(input, get_predicates)?;
394 }
395 MirRelationExpr::Threshold { input } => {
396 let predicates = std::mem::take(predicates);
397 *relation = input.take_dangerous().filter(predicates).threshold();
398 self.action(relation, get_predicates)?;
399 }
400 MirRelationExpr::Project { input, outputs } => {
401 let predicates = predicates.drain(..).map(|mut predicate| {
402 predicate.permute(outputs);
403 predicate
404 });
405 *relation = input
406 .take_dangerous()
407 .filter(predicates)
408 .project(outputs.clone());
409
410 self.action(relation, get_predicates)?;
411 }
412 MirRelationExpr::Filter {
413 input,
414 predicates: predicates2,
415 } => {
416 *relation = input
417 .take_dangerous()
418 .filter(predicates.clone().into_iter().chain(predicates2.clone()));
419 self.action(relation, get_predicates)?;
420 }
421 MirRelationExpr::Map { input, scalars } => {
422 let (retained, pushdown) = Self::push_filters_through_map(
423 scalars,
424 predicates,
425 input.arity(),
426 all_errors,
427 )?;
428 let scalars = std::mem::take(scalars);
429 let mut result = input.take_dangerous();
430 if !pushdown.is_empty() {
431 result = result.filter(pushdown);
432 }
433 self.action(&mut result, get_predicates)?;
434 result = result.map(scalars);
435 if !retained.is_empty() {
436 result = result.filter(retained);
437 }
438 *relation = result;
439 }
440 MirRelationExpr::FlatMap { input, .. } => {
441 let (mut retained, pushdown) =
442 Self::push_filters_through_flat_map(predicates, input.arity());
443
444 // remove all predicates that were pushed down from the current Filter node
445 std::mem::swap(&mut retained, predicates);
446
447 if !pushdown.is_empty() {
448 // put the filter on top of the input
449 **input = input.take_dangerous().filter(pushdown);
450 }
451
452 // ... and keep pushing predicates down
453 self.action(input, get_predicates)?;
454 }
455 MirRelationExpr::Union { base, inputs } => {
456 let predicates = std::mem::take(predicates);
457 *base = Box::new(base.take_dangerous().filter(predicates.clone()));
458 self.action(base, get_predicates)?;
459 for input in inputs {
460 *input = input.take_dangerous().filter(predicates.clone());
461 self.action(input, get_predicates)?;
462 }
463 }
464 MirRelationExpr::Negate { input } => {
465 // Don't push literal errors past a Negate. The problem is that it's
466 // hard to appropriately reflect the negation in the error stream:
467 // - If we don't negate, then errors that should cancel out will not
468 // cancel out. For example, see
469 // https://github.com/MaterializeInc/database-issues/issues/5691
470 // - If we negate, then unrelated errors might cancel out. E.g., there
471 // might be a division-by-0 in both inputs to an EXCEPT ALL, but
472 // on different input data. These shouldn't cancel out.
473 let (retained, pushdown): (Vec<_>, Vec<_>) = std::mem::take(predicates)
474 .into_iter()
475 .partition(|p| p.is_literal_err());
476 let mut result = input.take_dangerous();
477 if !pushdown.is_empty() {
478 result = result.filter(pushdown);
479 }
480 self.action(&mut result, get_predicates)?;
481 result = result.negate();
482 if !retained.is_empty() {
483 result = result.filter(retained);
484 }
485 *relation = result;
486 }
487 x => {
488 x.try_visit_mut_children(|e| self.action(e, get_predicates))?;
489 }
490 }
491
492 // remove empty filters (junk by-product of the actual transform)
493 match relation {
494 MirRelationExpr::Filter { predicates, input } if predicates.is_empty() => {
495 *relation = input.take_dangerous();
496 }
497 _ => {}
498 }
499
500 Ok(())
501 }
502 MirRelationExpr::Get { id, .. } => {
503 // Purge all predicates associated with the id.
504 get_predicates
505 .entry(*id)
506 .or_insert_with(BTreeSet::new)
507 .clear();
508
509 Ok(())
510 }
511 MirRelationExpr::Let { id, body, value } => {
512 // Push predicates and collect intersection at `Get`s.
513 self.action(body, get_predicates)?;
514
515 // `get_predicates` should now contain the intersection
516 // of predicates at each *use* of the binding. If it is
517 // non-empty, we can move those predicates to the value.
518 Self::push_into_let_binding(get_predicates, id, value, &mut [body]);
519
520 // Continue recursively on the value.
521 self.action(value, get_predicates)
522 }
523 MirRelationExpr::LetRec {
524 ids,
525 values,
526 limits: _,
527 body,
528 } => {
529 // Note: This could be extended to do a bit more pushdown; see
530 // https://github.com/MaterializeInc/database-issues/issues/5336#issuecomment-1477588262
531
532 // Pre-compute which Ids are used across iterations
533 let ids_used_across_iterations = MirRelationExpr::recursive_ids(ids, values);
534
535 // Predicate pushdown within the body
536 self.action(body, get_predicates)?;
537
538 // `users` will be the body plus the values of those bindings that we have seen
539 // so far, while going one-by-one through the list of bindings backwards.
540 // `users` contains those expressions from which we harvested `get_predicates`,
541 // and therefore we should attend to all of these expressions when pushing down
542 // a predicate into a Let binding.
543 let mut users = vec![&mut **body];
544 for (id, value) in ids.iter_mut().zip(values).rev() {
545 // Predicate pushdown from Gets in `users` into the value of a Let binding
546 //
547 // For now, we simply always avoid pushing into a Let binding that is
548 // referenced across iterations to avoid soundness problems and infinite
549 // pushdowns.
550 //
551 // Note that `push_into_let_binding` makes a further check based on
552 // `get_predicates`: We push a predicate into the value of a binding, only
553 // if all Gets of this Id have this same predicate on top of them.
554 if !ids_used_across_iterations.contains(id) {
555 Self::push_into_let_binding(get_predicates, id, value, &mut users);
556 }
557
558 // Predicate pushdown within a binding
559 self.action(value, get_predicates)?;
560
561 users.push(value);
562 }
563
564 Ok(())
565 }
566 MirRelationExpr::Join {
567 inputs,
568 equivalences,
569 ..
570 } => {
571 // The goal is to push
572 // 1) equivalences of the form `expr = <runtime constant>`, where `expr`
573 // comes from a single input.
574 // 2) equivalences of the form `expr1 = expr2`, where both
575 // expressions come from the same single input.
576 let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
577 mz_expr::canonicalize::canonicalize_equivalences(
578 equivalences,
579 input_types.iter().map(|t| &t.column_types),
580 );
581
582 let input_mapper = mz_expr::JoinInputMapper::new_from_input_types(&input_types);
583 // Predicates to push at each input, and to lift out the join.
584 let mut push_downs = vec![Vec::new(); inputs.len()];
585
586 for equivalence_pos in 0..equivalences.len() {
587 // Case 1: there is more than one literal in the
588 // equivalence class. Because equivalences have been
589 // deduplicated, this means that everything in the equivalence
590 // class must be equal to two different literals, so the
591 // entire relation zeroes out.
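// For example, the equivalence `{#0, 3, 5}` would require both `#0 = 3`
// and `#0 = 5`, which is unsatisfiable.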
592 if equivalences[equivalence_pos]
593 .iter()
594 .filter(|expr| expr.is_literal())
595 .count()
596 > 1
597 {
598 relation.take_safely(Some(relation.typ_with_input_types(&input_types)));
599 return Ok(());
600 }
601
602 let runtime_constants = equivalences[equivalence_pos]
603 .iter()
604 .filter(|expr| expr.support().is_empty())
605 .cloned()
606 .collect::<Vec<_>>();
607 if !runtime_constants.is_empty() {
608 // Case 2: There is at least one runtime constant in the equivalence class.
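// For example (informal sketch): for the equivalence `{#0, #3, 7}`, with
// `#0` from the first input and `#3` from the second, we push `#0 = 7` to
// the first input and the localized equivalent of `#3 = 7` to the second,
// and remove both expressions from the equivalence.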
609 let gen_literal_equality_preds = |expr: MirScalarExpr| {
610 let mut equality_preds = Vec::new();
611 for constant in runtime_constants.iter() {
612 let pred = if constant.is_literal_null() {
613 MirScalarExpr::CallUnary {
614 func: mz_expr::UnaryFunc::IsNull(func::IsNull),
615 expr: Box::new(expr.clone()),
616 }
617 } else {
618 MirScalarExpr::CallBinary {
619 func: mz_expr::BinaryFunc::Eq,
620 expr1: Box::new(expr.clone()),
621 expr2: Box::new(constant.clone()),
622 }
623 };
624 equality_preds.push(pred);
625 }
626 equality_preds
627 };
628
629 // Find all single input expressions in the equivalence
630 // class and collect (position within the equivalence class,
631 // input the expression belongs to, localized version of the
632 // expression).
633 let mut single_input_exprs = equivalences[equivalence_pos]
634 .iter()
635 .enumerate()
636 .filter_map(|(pos, e)| {
637 let mut inputs = input_mapper.lookup_inputs(e);
638 if let Some(input) = inputs.next() {
639 if inputs.next().is_none() {
640 return Some((
641 pos,
642 input,
643 input_mapper.map_expr_to_local(e.clone()),
644 ));
645 }
646 }
647 None
648 })
649 .collect::<Vec<_>>();
650
651 // For every single-input expression `expr`, we can push
652 // down `expr = <runtime constant>` and remove `expr` from the
653 // equivalence class.
654 for (expr_pos, input, expr) in single_input_exprs.drain(..).rev() {
655 push_downs[input].extend(gen_literal_equality_preds(expr));
656 equivalences[equivalence_pos].remove(expr_pos);
657 }
658
659 // If only constants remain in the equivalence (none of the remaining
660 // expressions depend on input columns) and equality predicates with them
661 // have been pushed down, we can safely clear the equivalence.
662 // TODO: we could probably push equality predicates among the
663 // remaining constants to all join inputs to prevent any computation
664 // from happening until the condition is satisfied.
665 if equivalences[equivalence_pos]
666 .iter()
667 .all(|e| e.support().is_empty())
668 && push_downs.iter().any(|p| !p.is_empty())
669 {
670 equivalences[equivalence_pos].clear();
671 }
672 } else {
673 // Case 3: There are no constants in the equivalence
674 // class. Push a predicate for every pair of expressions
675 // in the equivalence that either belong to a single
676 // input or can be localized to a given input through
677 // the rest of the equivalences.
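// For example (informal sketch): for the equivalence `{#0, #1}` with both
// columns from the same input, we push
// `(#0 = #1) OR (#0 IS NULL AND #1 IS NULL)` down to that input and drop
// the equivalence entirely, since no cross-input constraint remains.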
678 let mut to_remove = Vec::new();
679 for input in 0..inputs.len() {
680 // Vector of pairs (position within the equivalence, localized
681 // expression). The position is None for expressions derived through
682 // other equivalences.
683 let localized = equivalences[equivalence_pos]
684 .iter()
685 .enumerate()
686 .filter_map(|(pos, expr)| {
687 if let MirScalarExpr::Column(col_pos) = &expr {
688 let local_col =
689 input_mapper.map_column_to_local(*col_pos);
690 if input == local_col.1 {
691 return Some((
692 Some(pos),
693 MirScalarExpr::Column(local_col.0),
694 ));
695 } else {
696 return None;
697 }
698 }
699 let mut inputs = input_mapper.lookup_inputs(expr);
700 if let Some(single_input) = inputs.next() {
701 if input == single_input && inputs.next().is_none() {
702 return Some((
703 Some(pos),
704 input_mapper.map_expr_to_local(expr.clone()),
705 ));
706 }
707 }
708 // Equivalences not including the current expression
709 let mut other_equivalences = equivalences.clone();
710 other_equivalences[equivalence_pos].remove(pos);
711 let mut localized = expr.clone();
712 if input_mapper.try_localize_to_input_with_bound_expr(
713 &mut localized,
714 input,
715 &other_equivalences[..],
716 ) {
717 Some((None, localized))
718 } else {
719 None
720 }
721 })
722 .collect::<Vec<_>>();
723
724 // If there are at least 2 expressions in the equivalence that
725 // can be localized to the same input, push all pairs
726 // of them to the input.
727 if localized.len() > 1 {
728 for mut pair in
729 localized.iter().map(|(_, expr)| expr).combinations(2)
730 {
731 let expr1 = pair.pop().unwrap();
732 let expr2 = pair.pop().unwrap();
733
734 push_downs[input].push(
735 MirScalarExpr::CallBinary {
736 func: mz_expr::BinaryFunc::Eq,
737 expr1: Box::new(expr2.clone()),
738 expr2: Box::new(expr1.clone()),
739 }
740 .or(expr2
741 .clone()
742 .call_is_null()
743 .and(expr1.clone().call_is_null())),
744 );
745 }
746
747 if localized.len() == equivalences[equivalence_pos].len() {
748 // The equivalence is either a single input one or fully localizable
749 // to a single input through other equivalences, so it can be removed
750 // completely without introducing any new cross join.
751 to_remove.extend(0..equivalences[equivalence_pos].len());
752 } else {
753 // Leave an expression from this input in the equivalence to avoid
754 // cross joins
755 to_remove.extend(
756 localized.iter().filter_map(|(pos, _)| *pos).skip(1),
757 );
758 }
759 }
760 }
761
762 // Remove expressions that were pushed down to at least one input
763 to_remove.sort();
764 to_remove.dedup();
765 for pos in to_remove.iter().rev() {
766 equivalences[equivalence_pos].remove(*pos);
767 }
768 };
769 }
770
771 mz_expr::canonicalize::canonicalize_equivalences(
772 equivalences,
773 input_types.iter().map(|t| &t.column_types),
774 );
775
776 Self::update_join_inputs_with_push_downs(inputs, push_downs);
777
778 // Recursively descend on each of the inputs.
779 for input in inputs.iter_mut() {
780 self.action(input, get_predicates)?;
781 }
782
783 Ok(())
784 }
785 x => {
786 // Recursively descend.
787 x.try_visit_mut_children(|e| self.action(e, get_predicates))
788 }
789 }
790 })
791 }
792
793 fn update_join_inputs_with_push_downs(
794 inputs: &mut Vec<MirRelationExpr>,
795 push_downs: Vec<Vec<MirScalarExpr>>,
796 ) {
797 let new_inputs = inputs
798 .drain(..)
799 .zip(push_downs)
800 .map(|(input, push_down)| {
801 if !push_down.is_empty() {
802 input.filter(push_down)
803 } else {
804 input
805 }
806 })
807 .collect();
808 *inputs = new_inputs;
809 }
810
811 // Checks `get_predicates` to see whether we can push a predicate into the Let binding given
812 // by `id` and `value`.
813 // `users` is the list of those expressions from which we will need to remove a predicate that
814 // is being pushed.
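// For example (informal sketch): with `users = [Filter p (Get l), ...]` and
// `get_predicates[l] = {p}`, the predicate `p` is removed from each such
// `Filter` and `value` is replaced by `Filter p value`.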
815 fn push_into_let_binding(
816 get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
817 id: &LocalId,
818 value: &mut MirRelationExpr,
819 users: &mut [&mut MirRelationExpr],
820 ) {
821 if let Some(list) = get_predicates.remove(&Id::Local(*id)) {
822 if !list.is_empty() {
823 // Remove the predicates in `list` from the users.
824 for user in users {
825 user.visit_pre_mut(|e| {
826 if let MirRelationExpr::Filter { input, predicates } = e {
827 if let MirRelationExpr::Get { id: get_id, .. } = **input {
828 if get_id == Id::Local(*id) {
829 predicates.retain(|p| !list.contains(p));
830 }
831 }
832 }
833 });
834 }
835 // Apply the predicates in `list` to value. Canonicalize
836 // `list` so that plans are always deterministic.
837 let mut list = list.into_iter().collect::<Vec<_>>();
838 mz_expr::canonicalize::canonicalize_predicates(
839 &mut list,
840 &value.typ().column_types,
841 );
842 *value = value.take_dangerous().filter(list);
843 }
844 }
845 }
846
847 /// Returns `(<predicates to retain>, <predicates to push at each input>)`.
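///
/// A minimal sketch (not a doctest), reusing `input1` and `input2` from the
/// module-level example and assuming both are one-column inputs joined on
/// their columns:
///
/// ```ignore
/// let inputs = vec![input1.clone(), input2.clone()];
/// let input_mapper = JoinInputMapper::new(&inputs);
/// // One equivalence class: global column #0 (input 0) equals #1 (input 1).
/// let equivalences = vec![vec![MirScalarExpr::column(0), MirScalarExpr::column(1)]];
/// let predicates = vec![MirScalarExpr::column(0)];
/// let (retain, push_downs) =
///     PredicatePushdown::push_filters_through_join(&input_mapper, &equivalences, predicates);
/// // `retain` is empty; both `push_downs[0]` and `push_downs[1]` contain the
/// // predicate localized to the respective input (for input 1, via the equivalence).
/// ```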
848 pub fn push_filters_through_join(
849 input_mapper: &JoinInputMapper,
850 equivalences: &Vec<Vec<MirScalarExpr>>,
851 mut predicates: Vec<MirScalarExpr>,
852 ) -> (Vec<MirScalarExpr>, Vec<Vec<MirScalarExpr>>) {
853 let mut push_downs = vec![Vec::new(); input_mapper.total_inputs()];
854 let mut retain = Vec::new();
855
856 for predicate in predicates.drain(..) {
857 // Track if the predicate has been pushed to at least one input.
858 let mut pushed = false;
859 // For each input, try and see if the join
860 // equivalences allow the predicate to be rewritten
861 // in terms of only columns from that input.
862 for (index, push_down) in push_downs.iter_mut().enumerate() {
863 #[allow(deprecated)] // TODO: use `might_error` if possible.
864 if predicate.is_literal_err() || predicate.contains_error_if_null() {
865 // Do nothing. We don't push down literal errors,
866 // as we can't know the join will be non-empty.
867 //
868 // We also don't want to push anything that involves `error_if_null`. This is
869 // for the same reason why in theory we shouldn't really push anything that can
870 // error, assuming that we want to preserve error semantics. (Because we would
871 // create a spurious error if some other Join input ends up empty.) We can't fix
872 // this problem in general (as we can't just not push anything that might
873 // error), but we decided to fix the specific problem instance involving
874 // `error_if_null`, because it was very painful:
875 // <https://github.com/MaterializeInc/database-issues/issues/6258>
876 } else {
877 let mut localized = predicate.clone();
878 if input_mapper.try_localize_to_input_with_bound_expr(
879 &mut localized,
880 index,
881 equivalences,
882 ) {
883 push_down.push(localized);
884 pushed = true;
885 } else if let Some(consequence) = input_mapper
886 // (`consequence_for_input` assumes that
887 // `try_localize_to_input_with_bound_expr` has already
888 // been called on `localized`.)
889 .consequence_for_input(&localized, index)
890 {
891 push_down.push(consequence);
892 // We don't set `pushed` here! We want to retain the
893 // predicate, because we only pushed a consequence of
894 // it, but not the full predicate.
895 }
896 }
897 }
898
899 if !pushed {
900 retain.push(predicate);
901 }
902 }
903
904 (retain, push_downs)
905 }
906
907 /// Computes "safe" predicates to push through a Map.
908 ///
909 /// In the case of a Filter { Map {...} }, we can always push down the Filter
910 /// by inlining expressions from the Map. We don't want to do this in general,
911 /// however, since general inlining can result in exponential blowup in the size
912 /// of expressions, so we only do this in the case where the size after inlining
913 /// is below a certain limit.
914 ///
915 /// Returns the predicates that cannot be pushed down, followed by the ones that can.
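///
/// A minimal sketch (not a doctest), assuming the imports from the
/// module-level example, a one-column input whose Map appends `#1 := #0 + #0`,
/// and a predicate referencing the mapped column:
///
/// ```ignore
/// let map_exprs = vec![
///     MirScalarExpr::column(0).call_binary(MirScalarExpr::column(0), BinaryFunc::AddInt64),
/// ];
/// let mut predicates = vec![MirScalarExpr::column(1)];
/// let (retained, pushdown) =
///     PredicatePushdown::push_filters_through_map(&map_exprs, &mut predicates, 1, false).unwrap();
/// // `retained` is empty; `pushdown` contains `#0 + #0`, i.e. the predicate with
/// // the Map expression inlined, which can now be applied below the Map.
/// ```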
916 pub fn push_filters_through_map(
917 map_exprs: &Vec<MirScalarExpr>,
918 predicates: &mut Vec<MirScalarExpr>,
919 input_arity: usize,
920 all_errors: bool,
921 ) -> Result<(Vec<MirScalarExpr>, Vec<MirScalarExpr>), TransformError> {
922 let mut pushdown = Vec::new();
923 let mut retained = Vec::new();
924 for predicate in predicates.drain(..) {
925 // We don't push down literal errors, unless all predicates are.
926 if !predicate.is_literal_err() || all_errors {
927 // Consider inlining Map expressions.
928 if let Some(cleaned) =
929 Self::inline_if_not_too_big(&predicate, input_arity, map_exprs)?
930 {
931 pushdown.push(cleaned);
932 } else {
933 retained.push(predicate);
934 }
935 } else {
936 retained.push(predicate);
937 }
938 }
939 Ok((retained, pushdown))
940 }
941
942 /// This fn should be called with a Filter `expr` that is after a Map. `input_arity` is the
943 /// arity of the input of the Map. This fn eliminates those column references in `expr` that
944 /// refer not to a column of the Map's input but to a column created by the Map. It does this
945 /// by transitively inlining Map expressions until no reference to a Map-created column
946 /// remains. The return value is the cleaned-up expression. The fn bails out with None if the
947 /// resulting expression would be made too big by the inlining.
948 ///
949 /// OOO: (Optimizer Optimization Opportunity) This function might do work proportional to the
950 /// total size of the Map expressions. We call this function for each predicate above the Map,
951 /// which is roughly quadratic, i.e., if there are many predicates and a big Map, then this
952 /// will be slow. We could instead pass a vector of Map expressions and call this fn only once.
953 /// The only downside would be that then the inlining limit being hit in the middle part of this
954 /// function would prevent us from inlining any predicates, even ones that wouldn't hit the
955 /// inlining limit if considered on their own.
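///
/// For example (informal sketch, writing `#n` for column `n`), with
/// `input_arity = 1` and `map_exprs = [#0 + #0, #1 + #1]`:
///
/// ```text
/// expr:   #2 > 7
/// result: ((#0 + #0) + (#0 + #0)) > 7    -- assuming this stays under the size limit
/// ```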
956 fn inline_if_not_too_big(
957 expr: &MirScalarExpr,
958 input_arity: usize,
959 map_exprs: &Vec<MirScalarExpr>,
960 ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
961 let size_limit = 1000;
962
963 // Transitively determine the support of `expr` produced by `map_exprs`
964 // that needs to be inlined.
965 let cols_to_inline = {
966 let mut support = BTreeSet::new();
967
968 // Seed with `map_exprs` support in `expr`.
969 expr.visit_pre(|e| {
970 if let MirScalarExpr::Column(c) = e {
971 if *c >= input_arity {
972 support.insert(*c);
973 }
974 }
975 });
976
977 // Compute transitive closure of supports in `map_exprs`.
978 let mut workset = support.iter().cloned().collect::<Vec<_>>();
979 let mut buffer = vec![];
980 while !workset.is_empty() {
981 // Swap the (empty) `drained` buffer with the `workset`.
982 std::mem::swap(&mut workset, &mut buffer);
983 // Drain the `buffer` and update `support` and `workset`.
984 for c in buffer.drain(..) {
985 map_exprs[c - input_arity].visit_pre(|e| {
986 if let MirScalarExpr::Column(c) = e {
987 if *c >= input_arity {
988 if support.insert(*c) {
989 workset.push(*c);
990 }
991 }
992 }
993 });
994 }
995 }
996 support
997 };
998
999 let mut inlined = BTreeMap::<usize, (MirScalarExpr, usize)>::new();
1000 // Populate the memo table in ascending column order (which respects the
1001 // dependency order of `map_exprs` references). Break early if memoization
1002 // fails for one of the columns in `cols_to_inline`.
1003 for c in cols_to_inline.iter() {
1004 let mut new_expr = map_exprs[*c - input_arity].clone();
1005 let mut new_size = 0;
1006 new_expr.visit_mut_post(&mut |expr| {
1007 new_size += 1;
1008 if let MirScalarExpr::Column(c) = expr {
1009 if *c >= input_arity && new_size <= size_limit {
1010 // (inlined[c] is safe, because we proceed in column order, and we break out
1011 // of the loop when we stop inserting into memo.)
1012 let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1013 *expr = m_expr.clone();
1014 new_size += m_size - 1; // Adjust for the +1 above.
1015 }
1016 }
1017 })?;
1018
1019 if new_size <= size_limit {
1020 inlined.insert(*c, (new_expr, new_size));
1021 } else {
1022 break;
1023 }
1024 }
1025
1026 // Try to resolve expr against the memo table.
1027 if inlined.len() < cols_to_inline.len() {
1028 Ok(None) // We couldn't memoize all map expressions within the given limit.
1029 } else {
1030 let mut new_expr = expr.clone();
1031 let mut new_size = 0;
1032 new_expr.visit_mut_post(&mut |expr| {
1033 new_size += 1;
1034 if let MirScalarExpr::Column(c) = expr {
1035 if *c >= input_arity && new_size <= size_limit {
1036 // (inlined[c] is safe because of the outer if condition.)
1037 let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1038 *expr = m_expr.clone();
1039 new_size += m_size - 1; // Adjust for the +1 above.
1040 }
1041 }
1042 })?;
1043
1044 soft_assert_eq_no_log!(new_size, new_expr.size());
1045 if new_size <= size_limit {
1046 Ok(Some(new_expr)) // We managed to stay within the limit.
1047 } else {
1048 Ok(None) // Limit exceeded.
1049 }
1050 }
1051 }
1052 // fn inline_if_not_too_big(
1053 // expr: &MirScalarExpr,
1054 // input_arity: usize,
1055 // map_exprs: &Vec<MirScalarExpr>,
1056 // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1057 // let size_limit = 1000;
1058 // // Memoize cleaned up versions of Map expressions. (Not necessarily all the Map expressions
1059 // // will be involved.)
1060 // let mut memo: BTreeMap<MirScalarExpr, MirScalarExpr> = BTreeMap::new();
1061 // fn rec(
1062 // expr: &MirScalarExpr,
1063 // input_arity: usize,
1064 // map_exprs: &Vec<MirScalarExpr>,
1065 // memo: &mut BTreeMap<MirScalarExpr, MirScalarExpr>,
1066 // size_limit: usize,
1067 // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1068 // // (We can't use Entry::or_insert_with, because the closure would need to be fallible.
1069 // // We also can't manually match on the result of memo.entry, because that holds a
1070 // // borrow of memo, but we need to pass memo to the recursive call in the middle.)
1071 // match memo.get(expr) {
1072 // Some(memoized_result) => Ok(Some(memoized_result.clone())),
1073 // None => {
1074 // let mut expr_size = expr.size()?;
1075 // let mut cleaned_expr = expr.clone();
1076 // let mut bail = false;
1077 // cleaned_expr.try_visit_mut_post(&mut |expr| {
1078 // Ok(if !bail {
1079 // match expr {
1080 // MirScalarExpr::Column(col) => {
1081 // if *col >= input_arity {
1082 // let to_inline = rec(
1083 // &map_exprs[*col - input_arity],
1084 // input_arity,
1085 // map_exprs,
1086 // memo,
1087 // size_limit,
1088 // )?;
1089 // if let Some(to_inline) = to_inline {
1090 // // The `-1` is because the expression that we are
1091 // // replacing has a size of 1.
1092 // expr_size += to_inline.size()? - 1;
1093 // *expr = to_inline;
1094 // if expr_size > size_limit {
1095 // bail = true;
1096 // }
1097 // } else {
1098 // bail = true;
1099 // }
1100 // }
1101 // }
1102 // _ => (),
1103 // }
1104 // })
1105 // })?;
1106 // soft_assert_eq!(cleaned_expr.size()?, expr_size);
1107 // if !bail {
1108 // memo.insert(expr.clone(), cleaned_expr.clone());
1109 // Ok(Some(cleaned_expr))
1110 // } else {
1111 // Ok(None)
1112 // }
1113 // }
1114 // }
1115 // }
1116 // rec(expr, input_arity, map_exprs, &mut memo, size_limit)
1117 // }
1118
1119 /// Computes "safe" predicate to push through a FlatMap.
1120 ///
1121 /// In the case of a Filter { FlatMap {...} }, we want to push through all predicates
1122 /// that (1) are not literal errors and (2) have support exclusively in the columns
1123 /// provided by the FlatMap input.
1124 ///
1125 /// Returns the predicates that cannot be pushed down, followed by the ones that can.
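///
/// For example (informal sketch), with `input_arity = 2` (the FlatMap appends
/// columns `#2, #3, ...`):
///
/// ```text
/// #0 > #1        => pushed below the FlatMap
/// #2 = 5         => retained (references a FlatMap output column)
/// error("boom")  => retained (literal error)
/// ```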
1126 fn push_filters_through_flat_map(
1127 predicates: &mut Vec<MirScalarExpr>,
1128 input_arity: usize,
1129 ) -> (Vec<MirScalarExpr>, Vec<MirScalarExpr>) {
1130 let mut pushdown = Vec::new();
1131 let mut retained = Vec::new();
1132 for predicate in predicates.drain(..) {
1133 // First, check if we can push this predicate down. We can do so if and only if:
1134 // (1) the predicate is not a literal error, and
1135 // (2) each column it references is from the input.
1136 if (!predicate.is_literal_err()) && predicate.support().iter().all(|c| *c < input_arity)
1137 {
1138 pushdown.push(predicate);
1139 } else {
1140 retained.push(predicate);
1141 }
1142 }
1143 (retained, pushdown)
1144 }
1145
1146 /// If `s` is of the form
1147 /// `(isnull(expr1) && isnull(expr2)) || (expr1 = expr2)`, or
1148 /// `(decompose_is_null(expr1) && decompose_is_null(expr2)) || (expr1 = expr2)`,
1149 /// extract `expr1` and `expr2`.
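///
/// This pattern is the usual expansion of SQL's `a IS NOT DISTINCT FROM b`,
/// so recognizing it allows such predicates to become join equivalences:
///
/// ```text
/// (#0 IS NULL AND #1 IS NULL) OR (#0 = #1)   =>   Some((#0, #1))
/// ```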
1150 fn extract_equal_or_both_null(
1151 s: &mut MirScalarExpr,
1152 column_types: &[ColumnType],
1153 ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1154 if let MirScalarExpr::CallVariadic {
1155 func: VariadicFunc::Or,
1156 exprs,
1157 } = s
1158 {
1159 if let &[ref or_lhs, ref or_rhs] = &**exprs {
1160 // Check both orders of operands of the OR
1161 return Self::extract_equal_or_both_null_inner(or_lhs, or_rhs, column_types)
1162 .or_else(|| {
1163 Self::extract_equal_or_both_null_inner(or_rhs, or_lhs, column_types)
1164 });
1165 }
1166 }
1167 None
1168 }
1169
1170 fn extract_equal_or_both_null_inner(
1171 or_arg1: &MirScalarExpr,
1172 or_arg2: &MirScalarExpr,
1173 column_types: &[ColumnType],
1174 ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1175 use mz_expr::BinaryFunc;
1176 if let MirScalarExpr::CallBinary {
1177 func: BinaryFunc::Eq,
1178 expr1: eq_lhs,
1179 expr2: eq_rhs,
1180 } = &or_arg2
1181 {
1182 let isnull1 = eq_lhs.clone().call_is_null();
1183 let isnull2 = eq_rhs.clone().call_is_null();
1184 let both_null = MirScalarExpr::CallVariadic {
1185 func: VariadicFunc::And,
1186 exprs: vec![isnull1, isnull2],
1187 };
1188
1189 if Self::extract_reduced_conjunction_terms(both_null, column_types)
1190 == Self::extract_reduced_conjunction_terms(or_arg1.clone(), column_types)
1191 {
1192 return Some(((**eq_lhs).clone(), (**eq_rhs).clone()));
1193 }
1194 }
1195 None
1196 }
1197
1198 /// Reduces the given expression and returns its AND-ed terms.
1199 fn extract_reduced_conjunction_terms(
1200 mut s: MirScalarExpr,
1201 column_types: &[ColumnType],
1202 ) -> Vec<MirScalarExpr> {
1203 s.reduce(column_types);
1204
1205 if let MirScalarExpr::CallVariadic {
1206 func: VariadicFunc::And,
1207 exprs,
1208 } = s
1209 {
1210 exprs
1211 } else {
1212 vec![s]
1213 }
1214 }
1215}