mz_transform/predicate_pushdown.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes predicates down through other operators.
11//!
12//! This action generally improves the quality of the query, in that selective per-record
13//! filters reduce the volume of data before they arrive at more expensive operators.
14//!
15//!
//! The one time when this action might not improve the quality of a query is
//! if a filter gets pushed down on an arrangement, because that blocks arrangement
//! reuse. It is assumed that actions that need an arrangement are responsible for
//! lifting filters out of the way.
20//!
21//! Predicate pushdown will not push down literal errors, unless it is certain that
22//! the literal errors will be unconditionally evaluated. For example, the pushdown
23//! will not happen if not all predicates can be pushed down (e.g. reduce and map),
24//! or if we are not certain that the input is non-empty (e.g. join).
25//! Note that this is not addressing the problem in its full generality, because this problem can
26//! occur with any function call that might error (although much more rarely than with literal
27//! errors). See <https://github.com/MaterializeInc/database-issues/issues/4972#issuecomment-1547391011>
28//!
29//! ```rust
30//! use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr, func};
31//! use mz_ore::id_gen::IdGen;
32//! use mz_repr::{ReprColumnType, ReprRelationType, ReprScalarType};
33//! use mz_repr::optimize::OptimizerFeatures;
34//! use mz_transform::{reprtypecheck, Transform, TransformCtx};
35//! use mz_transform::dataflow::DataflowMetainfo;
36//!
37//! use mz_transform::predicate_pushdown::PredicatePushdown;
38//!
39//! let bool_typ = ReprRelationType::new(vec![
40//! ReprColumnType { scalar_type: ReprScalarType::Bool, nullable: false },
41//! ]);
42//! let input1 = MirRelationExpr::Constant {
43//! rows: Ok(vec![]),
44//! typ: bool_typ.clone(),
45//! };
46//! let input2 = MirRelationExpr::Constant {
47//! rows: Ok(vec![]),
48//! typ: bool_typ.clone(),
49//! };
50//! let input3 = MirRelationExpr::Constant {
51//! rows: Ok(vec![]),
52//! typ: bool_typ,
53//! };
54//! let join = MirRelationExpr::join(
55//! vec![input1.clone(), input2.clone(), input3.clone()],
56//! vec![vec![(0, 0), (2, 0)].into_iter().collect()],
57//! );
58//!
59//! let predicate0 = MirScalarExpr::column(0);
60//! let predicate1 = MirScalarExpr::column(1);
61//! let predicate01 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(2), func::AddInt64);
62//! let predicate012 = MirScalarExpr::literal_false();
63//!
64//! let mut expr = join.filter(
65//! vec![
66//! predicate0.clone(),
67//! predicate1.clone(),
68//! predicate01.clone(),
69//! predicate012.clone(),
70//! ]);
71//!
72//! let features = OptimizerFeatures::default();
73//! let typecheck_ctx = reprtypecheck::empty_context();
74//! let mut df_meta = DataflowMetainfo::default();
75//! let mut transform_ctx = TransformCtx::local(&features, &typecheck_ctx, &mut df_meta, None, None);
76//!
77//! PredicatePushdown::default().transform(&mut expr, &mut transform_ctx);
78//!
79//! let predicate00 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(0), func::AddInt64);
80//! let expected_expr = MirRelationExpr::join(
81//! vec![
82//! input1.clone().filter(vec![predicate0.clone(), predicate00.clone()]),
83//! input2.clone().filter(vec![predicate0.clone()]),
84//! input3.clone().filter(vec![predicate0, predicate00])
85//! ],
86//! vec![vec![(0, 0), (2, 0)].into_iter().collect()],
87//! ).filter(vec![predicate012]);
88//! assert_eq!(expected_expr, expr)
89//! ```
90
91use std::collections::{BTreeMap, BTreeSet};
92
93use itertools::Itertools;
94use mz_expr::func::variadic::And;
95use mz_expr::visit::{Visit, VisitChildren};
96use mz_expr::{
97 AggregateFunc, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
98 VariadicFunc, func,
99};
100use mz_ore::soft_assert_eq_no_log;
101use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
102use mz_repr::{Datum, ReprColumnType, ReprScalarType};
103
104use crate::{TransformCtx, TransformError};
105
106/// Pushes predicates down through other operators.
107#[derive(Debug)]
108pub struct PredicatePushdown {
109 recursion_guard: RecursionGuard,
110}
111
112impl Default for PredicatePushdown {
113 fn default() -> PredicatePushdown {
114 PredicatePushdown {
115 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
116 }
117 }
118}
119
120impl CheckedRecursion for PredicatePushdown {
121 fn recursion_guard(&self) -> &RecursionGuard {
122 &self.recursion_guard
123 }
124}
125
126impl crate::Transform for PredicatePushdown {
127 fn name(&self) -> &'static str {
128 "PredicatePushdown"
129 }
130
131 #[mz_ore::instrument(
132 target = "optimizer",
133 level = "debug",
134 fields(path.segment = "predicate_pushdown")
135 )]
136 fn actually_perform_transform(
137 &self,
138 relation: &mut MirRelationExpr,
139 _: &mut TransformCtx,
140 ) -> Result<(), TransformError> {
141 let mut empty = BTreeMap::new();
142 let result = self.action(relation, &mut empty);
143 mz_repr::explain::trace_plan(&*relation);
144 result
145 }
146}
147
148impl PredicatePushdown {
149 /// Predicate pushdown
150 ///
151 /// This method looks for opportunities to push predicates toward
152 /// sources of data. Primarily, this is the `Filter` expression,
153 /// and moving its predicates through the operators it contains.
154 ///
155 /// In addition, the method accumulates the intersection of predicates
156 /// applied to each `Get` expression, so that the predicate can
157 /// then be pushed through to a `Let` binding, or to the external
158 /// source of the data if the `Get` binds to another view.
159 pub fn action(
160 &self,
161 relation: &mut MirRelationExpr,
162 get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
163 ) -> Result<(), TransformError> {
164 self.checked_recur(|_| {
165 // In the case of Filter or Get we have specific work to do;
166 // otherwise we should recursively descend.
167 match relation {
168 MirRelationExpr::Filter { input, predicates } => {
169 // Reduce the predicates to determine as best as possible
170 // whether they are literal errors before working with them.
171 let input_type = input.typ();
172 for predicate in predicates.iter_mut() {
173 predicate.reduce(&input_type.column_types);
174 }
175
176 // It can be helpful to know if there are any non-literal errors,
177 // as this is justification for not pushing down literal errors.
178 let all_errors = predicates.iter().all(|p| p.is_literal_err());
179 // Depending on the type of `input` we have different
180 // logic to apply to consider pushing `predicates` down.
181 match &mut **input {
182 MirRelationExpr::Let { body, .. }
183 | MirRelationExpr::LetRec { body, .. } => {
184 // Push all predicates to the body.
185 **body = body
186 .take_dangerous()
187 .filter(std::mem::replace(predicates, Vec::new()));
188
189 self.action(input, get_predicates)?;
190 }
191 MirRelationExpr::Get { id, .. } => {
192 // We can report the predicates upward in `get_predicates`,
193 // but we are not yet able to delete them from the
194 // `Filter`.
195 get_predicates
196 .entry(*id)
197 .or_insert_with(|| predicates.iter().cloned().collect())
198 .retain(|p| predicates.contains(p));
199 }
200 MirRelationExpr::Join {
201 inputs,
202 equivalences,
203 ..
204 } => {
205 // We want to scan `predicates` for any that can
206 // 1) become join variable constraints
207 // 2) apply to individual elements of `inputs`.
208 // Figuring out the set of predicates that belong to
209 // the latter group requires 1) knowing which predicates
210 // are in the former group and 2) that the variable
211 // constraints be in canonical form.
212 // Thus, there is a first scan across `predicates` to
213 // populate the join variable constraints
214 // and a second scan across the remaining predicates
215 // to see which ones can become individual elements of
216 // `inputs`.
217
218 let input_mapper = mz_expr::JoinInputMapper::new(inputs);
219
220 // Predicates not translated into join variable
221 // constraints. We will attempt to push them at all
222 // inputs, and failing to
223 let mut pred_not_translated = Vec::new();
224
225 for mut predicate in predicates.drain(..) {
226 use mz_expr::{BinaryFunc, UnaryFunc};
227 if let MirScalarExpr::CallBinary {
228 func: BinaryFunc::Eq(_),
229 expr1,
230 expr2,
231 } = &predicate
232 {
233 // Translate into join variable constraints:
234 // 1) `nonliteral1 == nonliteral2` constraints
235 // 2) `expr == literal` where `expr` refers to more
236 // than one input.
237 let input_count =
238 input_mapper.lookup_inputs(&predicate).count();
239 if (!expr1.is_literal() && !expr2.is_literal())
240 || input_count >= 2
241 {
242 // `col1 == col2` as a `MirScalarExpr`
243 // implies `!isnull(col1)` as well.
244 // `col1 == col2` as a join constraint does
245 // not have this extra implication.
246 // Thus, when translating the
247 // `MirScalarExpr` to a join constraint, we
248 // need to retain the `!isnull(col1)`
249 // information.
250 if expr1.typ(&input_type.column_types).nullable {
251 pred_not_translated.push(
252 expr1
253 .clone()
254 .call_unary(UnaryFunc::IsNull(func::IsNull))
255 .call_unary(UnaryFunc::Not(func::Not)),
256 );
257 } else if expr2.typ(&input_type.column_types).nullable {
258 pred_not_translated.push(
259 expr2
260 .clone()
261 .call_unary(UnaryFunc::IsNull(func::IsNull))
262 .call_unary(UnaryFunc::Not(func::Not)),
263 );
264 }
265 equivalences
266 .push(vec![(**expr1).clone(), (**expr2).clone()]);
267 continue;
268 }
269 } else if let Some((expr1, expr2)) =
270 Self::extract_equal_or_both_null(
271 &mut predicate,
272 &input_type.column_types,
273 )
274 {
275 // Also translate into join variable constraints:
276 // 3) `((nonliteral1 = nonliteral2) || (nonliteral
277 // is null && nonliteral2 is null))`
278 equivalences.push(vec![expr1, expr2]);
279 continue;
280 }
281 pred_not_translated.push(predicate)
282 }
283
284 mz_expr::canonicalize::canonicalize_equivalences(
285 equivalences,
286 std::iter::once(&input_type.column_types),
287 );
288
289 let (retain, push_downs) = Self::push_filters_through_join(
290 &input_mapper,
291 equivalences,
292 pred_not_translated,
293 );
294
295 Self::update_join_inputs_with_push_downs(inputs, push_downs);
296
297 // Recursively descend on the join
298 self.action(input, get_predicates)?;
299
300 // remove all predicates that were pushed down from the current Filter node
301 *predicates = retain;
302 }
303 MirRelationExpr::Reduce {
304 input: inner,
305 group_key,
306 aggregates,
307 monotonic: _,
308 expected_group_size: _,
309 } => {
310 let mut retain = Vec::new();
311 let mut push_down = Vec::new();
312 for predicate in predicates.drain(..) {
313 // Do not push down literal errors unless it is only errors.
314 if !predicate.is_literal_err() || all_errors {
315 let mut supported = true;
316 let mut new_predicate = predicate.clone();
317 new_predicate.visit_pre(|e| {
318 if let MirScalarExpr::Column(c, _) = e {
319 if *c >= group_key.len() {
320 supported = false;
321 }
322 }
323 });
324 if supported {
325 new_predicate.visit_mut_post(&mut |e| {
326 if let MirScalarExpr::Column(i, _) = e {
327 *e = group_key[*i].clone();
328 }
329 })?;
330 push_down.push(new_predicate);
331 } else if let MirScalarExpr::Column(col, _) = &predicate {
332 if *col == group_key.len()
333 && aggregates.len() == 1
334 && aggregates[0].func == AggregateFunc::Any
335 {
336 push_down.push(aggregates[0].expr.clone());
337 aggregates[0].expr = MirScalarExpr::literal_ok(
338 Datum::True,
339 ReprScalarType::Bool,
340 );
341 } else {
342 retain.push(predicate);
343 }
344 } else {
345 retain.push(predicate);
346 }
347 } else {
348 retain.push(predicate);
349 }
350 }
351
352 if !push_down.is_empty() {
353 **inner = inner.take_dangerous().filter(push_down);
354 }
355 self.action(inner, get_predicates)?;
356
357 // remove all predicates that were pushed down from the current Filter node
358 std::mem::swap(&mut retain, predicates);
359 }
360 MirRelationExpr::TopK {
361 input,
362 group_key,
363 order_key: _,
364 limit,
365 offset: _,
366 monotonic: _,
367 expected_group_size: _,
368 } => {
369 let mut retain = Vec::new();
370 let mut push_down = Vec::new();
371
372 let mut support = BTreeSet::new();
373 support.extend(group_key.iter().cloned());
374 if let Some(limit) = limit {
375 // Strictly speaking not needed because the
376 // `limit` support should be a subset of the
377 // `group_key` support, but we don't want to
378 // take this for granted here.
379 limit.support_into(&mut support);
380 }
381
382 for predicate in predicates.drain(..) {
383 // Do not push down literal errors unless it is only errors.
384 if (!predicate.is_literal_err() || all_errors)
385 && predicate.support().is_subset(&support)
386 {
387 push_down.push(predicate);
388 } else {
389 retain.push(predicate);
390 }
391 }
392
393 // remove all predicates that were pushed down from the current Filter node
394 std::mem::swap(&mut retain, predicates);
395
396 if !push_down.is_empty() {
397 **input = input.take_dangerous().filter(push_down);
398 }
399
400 self.action(input, get_predicates)?;
401 }
402 MirRelationExpr::Threshold { input } => {
403 let predicates = std::mem::take(predicates);
404 *relation = input.take_dangerous().filter(predicates).threshold();
405 self.action(relation, get_predicates)?;
406 }
407 MirRelationExpr::Project { input, outputs } => {
408 let predicates = predicates.drain(..).map(|mut predicate| {
409 predicate.permute(outputs);
410 predicate
411 });
412 *relation = input
413 .take_dangerous()
414 .filter(predicates)
415 .project(outputs.clone());
416
417 self.action(relation, get_predicates)?;
418 }
419 MirRelationExpr::Filter {
420 input,
421 predicates: predicates2,
422 } => {
423 *relation = input
424 .take_dangerous()
425 .filter(predicates.clone().into_iter().chain(predicates2.clone()));
426 self.action(relation, get_predicates)?;
427 }
428 MirRelationExpr::Map { input, scalars } => {
429 let (retained, pushdown) = Self::push_filters_through_map(
430 scalars,
431 predicates,
432 input.arity(),
433 all_errors,
434 )?;
435 let scalars = std::mem::take(scalars);
436 let mut result = input.take_dangerous();
437 if !pushdown.is_empty() {
438 result = result.filter(pushdown);
439 }
440 self.action(&mut result, get_predicates)?;
441 result = result.map(scalars);
442 if !retained.is_empty() {
443 result = result.filter(retained);
444 }
445 *relation = result;
446 }
447 MirRelationExpr::FlatMap { input, .. } => {
448 let (mut retained, pushdown) =
449 Self::push_filters_through_flat_map(predicates, input.arity());
450
451 // remove all predicates that were pushed down from the current Filter node
452 std::mem::swap(&mut retained, predicates);
453
454 if !pushdown.is_empty() {
455 // put the filter on top of the input
456 **input = input.take_dangerous().filter(pushdown);
457 }
458
459 // ... and keep pushing predicates down
460 self.action(input, get_predicates)?;
461 }
462 MirRelationExpr::Union { base, inputs } => {
463 let predicates = std::mem::take(predicates);
464 **base = base.take_dangerous().filter(predicates.clone());
465 self.action(base, get_predicates)?;
466 for input in inputs {
467 *input = input.take_dangerous().filter(predicates.clone());
468 self.action(input, get_predicates)?;
469 }
470 }
471 MirRelationExpr::Negate { input } => {
472 // Don't push literal errors past a Negate. The problem is that it's
473 // hard to appropriately reflect the negation in the error stream:
474 // - If we don't negate, then errors that should cancel out will not
475 // cancel out. For example, see
476 // https://github.com/MaterializeInc/database-issues/issues/5691
477 // - If we negate, then unrelated errors might cancel out. E.g., there
478 // might be a division-by-0 in both inputs to an EXCEPT ALL, but
479 // on different input data. These shouldn't cancel out.
480 let (retained, pushdown): (Vec<_>, Vec<_>) = std::mem::take(predicates)
481 .into_iter()
482 .partition(|p| p.is_literal_err());
483 let mut result = input.take_dangerous();
484 if !pushdown.is_empty() {
485 result = result.filter(pushdown);
486 }
487 self.action(&mut result, get_predicates)?;
488 result = result.negate();
489 if !retained.is_empty() {
490 result = result.filter(retained);
491 }
492 *relation = result;
493 }
494 x => {
495 x.try_visit_mut_children(|e| self.action(e, get_predicates))?;
496 }
497 }
498
499 // remove empty filters (junk by-product of the actual transform)
500 match relation {
501 MirRelationExpr::Filter { predicates, input } if predicates.is_empty() => {
502 *relation = input.take_dangerous();
503 }
504 _ => {}
505 }
506
507 Ok(())
508 }
509 MirRelationExpr::Get { id, .. } => {
510 // Purge all predicates associated with the id.
511 get_predicates
512 .entry(*id)
513 .or_insert_with(BTreeSet::new)
514 .clear();
515
516 Ok(())
517 }
518 MirRelationExpr::Let { id, body, value } => {
519 // Push predicates and collect intersection at `Get`s.
520 self.action(body, get_predicates)?;
521
522 // `get_predicates` should now contain the intersection
523 // of predicates at each *use* of the binding. If it is
524 // non-empty, we can move those predicates to the value.
525 Self::push_into_let_binding(get_predicates, id, value, &mut [body]);
526
527 // Continue recursively on the value.
528 self.action(value, get_predicates)
529 }
530 MirRelationExpr::LetRec {
531 ids,
532 values,
533 limits: _,
534 body,
535 } => {
536 // Note: This could be extended to be able to do a little more pushdowns, see
537 // https://github.com/MaterializeInc/database-issues/issues/5336#issuecomment-1477588262
538
539 // Pre-compute which Ids are used across iterations
540 let ids_used_across_iterations = MirRelationExpr::recursive_ids(ids, values);
541
542 // Predicate pushdown within the body
543 self.action(body, get_predicates)?;
544
545 // `users` will be the body plus the values of those bindings that we have seen
546 // so far, while going one-by-one through the list of bindings backwards.
547 // `users` contains those expressions from which we harvested `get_predicates`,
548 // and therefore we should attend to all of these expressions when pushing down
549 // a predicate into a Let binding.
550 let mut users = vec![&mut **body];
551 for (id, value) in ids.iter_mut().rev().zip_eq(values.into_iter().rev()) {
552 // Predicate pushdown from Gets in `users` into the value of a Let binding
553 //
554 // For now, we simply always avoid pushing into a Let binding that is
555 // referenced across iterations to avoid soundness problems and infinite
556 // pushdowns.
557 //
558 // Note that `push_into_let_binding` makes a further check based on
559 // `get_predicates`: We push a predicate into the value of a binding, only
560 // if all Gets of this Id have this same predicate on top of them.
561 if !ids_used_across_iterations.contains(id) {
562 Self::push_into_let_binding(get_predicates, id, value, &mut users);
563 }
564
565 // Predicate pushdown within a binding
566 self.action(value, get_predicates)?;
567
568 users.push(value);
569 }
570
571 Ok(())
572 }
573 MirRelationExpr::Join {
574 inputs,
575 equivalences,
576 ..
577 } => {
578 // The goal is to push
579 // 1) equivalences of the form `expr = <runtime constant>`, where `expr`
580 // comes from a single input.
581 // 2) equivalences of the form `expr1 = expr2`, where both
582 // expressions come from the same single input.
583 let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
584 mz_expr::canonicalize::canonicalize_equivalences(
585 equivalences,
586 input_types.iter().map(|t| &t.column_types),
587 );
588
589 let input_mapper = mz_expr::JoinInputMapper::new_from_input_types(&input_types);
590 // Predicates to push at each input, and to lift out the join.
591 let mut push_downs = vec![Vec::new(); inputs.len()];
592
593 for equivalence_pos in 0..equivalences.len() {
594 // Case 1: there are more than one literal in the
595 // equivalence class. Because of equivalences have been
596 // dedupped, this means that everything in the equivalence
597 // class must be equal to two different literals, so the
598 // entire relation zeroes out
599 if equivalences[equivalence_pos]
600 .iter()
601 .filter(|expr| expr.is_literal())
602 .count()
603 > 1
604 {
605 relation.take_safely(Some(relation.typ_with_input_types(&input_types)));
606 return Ok(());
607 }
608
609 let runtime_constants = equivalences[equivalence_pos]
610 .iter()
611 .filter(|expr| expr.support().is_empty())
612 .cloned()
613 .collect::<Vec<_>>();
614 if !runtime_constants.is_empty() {
615 // Case 2: There is at least one runtime constant the equivalence class
616 let gen_literal_equality_preds = |expr: MirScalarExpr| {
617 let mut equality_preds = Vec::new();
618 for constant in runtime_constants.iter() {
619 let pred = if constant.is_literal_null() {
620 MirScalarExpr::CallUnary {
621 func: mz_expr::UnaryFunc::IsNull(func::IsNull),
622 expr: Box::new(expr.clone()),
623 }
624 } else {
625 MirScalarExpr::CallBinary {
626 func: mz_expr::func::Eq.into(),
627 expr1: Box::new(expr.clone()),
628 expr2: Box::new(constant.clone()),
629 }
630 };
631 equality_preds.push(pred);
632 }
633 equality_preds
634 };
635
636 // Find all single input expressions in the equivalence
637 // class and collect (position within the equivalence class,
638 // input the expression belongs to, localized version of the
639 // expression).
640 let mut single_input_exprs = equivalences[equivalence_pos]
641 .iter()
642 .enumerate()
643 .filter_map(|(pos, e)| {
644 let mut inputs = input_mapper.lookup_inputs(e);
645 if let Some(input) = inputs.next() {
646 if inputs.next().is_none() {
647 return Some((
648 pos,
649 input,
650 input_mapper.map_expr_to_local(e.clone()),
651 ));
652 }
653 }
654 None
655 })
656 .collect::<Vec<_>>();
657
658 // For every single-input expression `expr`, we can push
659 // down `expr = <runtime constant>` and remove `expr` from the
660 // equivalence class.
661 for (expr_pos, input, expr) in single_input_exprs.drain(..).rev() {
662 push_downs[input].extend(gen_literal_equality_preds(expr));
663 equivalences[equivalence_pos].remove(expr_pos);
664 }
665
666 // If none of the expressions in the equivalence depend on input
667 // columns and equality predicates with them are pushed down,
668 // we can safely remove them from the equivalence.
669 // TODO: we could probably push equality predicates among the
670 // remaining constants to all join inputs to prevent any computation
671 // from happening until the condition is satisfied.
672 if equivalences[equivalence_pos]
673 .iter()
674 .all(|e| e.support().is_empty())
675 && push_downs.iter().any(|p| !p.is_empty())
676 {
677 equivalences[equivalence_pos].clear();
678 }
679 } else {
680 // Case 3: There are no constants in the equivalence
681 // class. Push a predicate for every pair of expressions
682 // in the equivalence that either belong to a single
683 // input or can be localized to a given input through
684 // the rest of equivalences.
685 let mut to_remove = Vec::new();
686 for input in 0..inputs.len() {
687 // Vector of pairs (position within the equivalence, localized
688 // expression). The position is None for expressions derived through
689 // other equivalences.
690 let localized = equivalences[equivalence_pos]
691 .iter()
692 .enumerate()
693 .filter_map(|(pos, expr)| {
694 if let MirScalarExpr::Column(col_pos, _) = &expr {
695 let local_col =
696 input_mapper.map_column_to_local(*col_pos);
697 if input == local_col.1 {
698 // TODO(mgree) !!! is it safe to propagate the name here?
699 return Some((
700 Some(pos),
701 MirScalarExpr::column(local_col.0),
702 ));
703 } else {
704 return None;
705 }
706 }
707 let mut inputs = input_mapper.lookup_inputs(expr);
708 if let Some(single_input) = inputs.next() {
709 if input == single_input && inputs.next().is_none() {
710 return Some((
711 Some(pos),
712 input_mapper.map_expr_to_local(expr.clone()),
713 ));
714 }
715 }
716 // Equivalences not including the current expression
717 let mut other_equivalences = equivalences.clone();
718 other_equivalences[equivalence_pos].remove(pos);
719 let mut localized = expr.clone();
720 if input_mapper.try_localize_to_input_with_bound_expr(
721 &mut localized,
722 input,
723 &other_equivalences[..],
724 ) {
725 Some((None, localized))
726 } else {
727 None
728 }
729 })
730 .collect::<Vec<_>>();
731
732 // If there are at least 2 expression in the equivalence that
733 // can be localized to the same input, push all combinations
734 // of them to the input.
735 if localized.len() > 1 {
736 for mut pair in
737 localized.iter().map(|(_, expr)| expr).combinations(2)
738 {
739 let expr1 = pair.pop().unwrap();
740 let expr2 = pair.pop().unwrap();
741
742 push_downs[input].push(
743 MirScalarExpr::CallBinary {
744 func: func::Eq.into(),
745 expr1: Box::new(expr2.clone()),
746 expr2: Box::new(expr1.clone()),
747 }
748 .or(expr2
749 .clone()
750 .call_is_null()
751 .and(expr1.clone().call_is_null())),
752 );
753 }
754
755 if localized.len() == equivalences[equivalence_pos].len() {
756 // The equivalence is either a single input one or fully localizable
757 // to a single input through other equivalences, so it can be removed
758 // completely without introducing any new cross join.
759 to_remove.extend(0..equivalences[equivalence_pos].len());
760 } else {
761 // Leave an expression from this input in the equivalence to avoid
762 // cross joins
763 to_remove.extend(
764 localized.iter().filter_map(|(pos, _)| *pos).skip(1),
765 );
766 }
767 }
768 }
769
770 // Remove expressions that were pushed down to at least one input
771 to_remove.sort();
772 to_remove.dedup();
773 for pos in to_remove.iter().rev() {
774 equivalences[equivalence_pos].remove(*pos);
775 }
776 };
777 }
778
779 mz_expr::canonicalize::canonicalize_equivalences(
780 equivalences,
781 input_types.iter().map(|t| &t.column_types),
782 );
783
784 Self::update_join_inputs_with_push_downs(inputs, push_downs);
785
786 // Recursively descend on each of the inputs.
787 for input in inputs.iter_mut() {
788 self.action(input, get_predicates)?;
789 }
790
791 Ok(())
792 }
793 x => {
794 // Recursively descend.
795 x.try_visit_mut_children(|e| self.action(e, get_predicates))
796 }
797 }
798 })
799 }
800
801 fn update_join_inputs_with_push_downs(
802 inputs: &mut Vec<MirRelationExpr>,
803 push_downs: Vec<Vec<MirScalarExpr>>,
804 ) {
805 let new_inputs = inputs
806 .drain(..)
807 .zip_eq(push_downs)
808 .map(|(input, push_down)| {
809 if !push_down.is_empty() {
810 input.filter(push_down)
811 } else {
812 input
813 }
814 })
815 .collect();
816 *inputs = new_inputs;
817 }
818
819 // Checks `get_predicates` to see whether we can push a predicate into the Let binding given
820 // by `id` and `value`.
821 // `users` is the list of those expressions from which we will need to remove a predicate that
822 // is being pushed.
823 fn push_into_let_binding(
824 get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
825 id: &LocalId,
826 value: &mut MirRelationExpr,
827 users: &mut [&mut MirRelationExpr],
828 ) {
829 if let Some(list) = get_predicates.remove(&Id::Local(*id)) {
830 if !list.is_empty() {
831 // Remove the predicates in `list` from the users.
832 for user in users {
833 user.visit_pre_mut(|e| {
834 if let MirRelationExpr::Filter { input, predicates } = e {
835 if let MirRelationExpr::Get { id: get_id, .. } = **input {
836 if get_id == Id::Local(*id) {
837 predicates.retain(|p| !list.contains(p));
838 }
839 }
840 }
841 });
842 }
843 // Apply the predicates in `list` to value. Canonicalize
844 // `list` so that plans are always deterministic.
845 let mut list = list.into_iter().collect::<Vec<_>>();
846 mz_expr::canonicalize::canonicalize_predicates(
847 &mut list,
848 &value.typ().column_types,
849 );
850 *value = value.take_dangerous().filter(list);
851 }
852 }
853 }
854
855 /// Returns `(<predicates to retain>, <predicates to push at each input>)`.
856 pub fn push_filters_through_join(
857 input_mapper: &JoinInputMapper,
858 equivalences: &Vec<Vec<MirScalarExpr>>,
859 mut predicates: Vec<MirScalarExpr>,
860 ) -> (Vec<MirScalarExpr>, Vec<Vec<MirScalarExpr>>) {
861 let mut push_downs = vec![Vec::new(); input_mapper.total_inputs()];
862 let mut retain = Vec::new();
863
864 for predicate in predicates.drain(..) {
865 // Track if the predicate has been pushed to at least one input.
866 let mut pushed = false;
867 // For each input, try and see if the join
868 // equivalences allow the predicate to be rewritten
869 // in terms of only columns from that input.
870 for (index, push_down) in push_downs.iter_mut().enumerate() {
871 #[allow(deprecated)] // TODO: use `might_error` if possible.
872 if predicate.is_literal_err() || predicate.contains_error_if_null() {
873 // Do nothing. We don't push down literal errors,
874 // as we can't know the join will be non-empty.
875 //
876 // We also don't want to push anything that involves `error_if_null`. This is
877 // for the same reason why in theory we shouldn't really push anything that can
878 // error, assuming that we want to preserve error semantics. (Because we would
879 // create a spurious error if some other Join input ends up empty.) We can't fix
880 // this problem in general (as we can't just not push anything that might
881 // error), but we decided to fix the specific problem instance involving
882 // `error_if_null`, because it was very painful:
883 // <https://github.com/MaterializeInc/database-issues/issues/6258>
884 } else {
885 let mut localized = predicate.clone();
886 if input_mapper.try_localize_to_input_with_bound_expr(
887 &mut localized,
888 index,
889 equivalences,
890 ) {
891 push_down.push(localized);
892 pushed = true;
893 } else if let Some(consequence) = input_mapper
894 // (`consequence_for_input` assumes that
895 // `try_localize_to_input_with_bound_expr` has already
896 // been called on `localized`.)
897 .consequence_for_input(&localized, index)
898 {
899 push_down.push(consequence);
900 // We don't set `pushed` here! We want to retain the
901 // predicate, because we only pushed a consequence of
902 // it, but not the full predicate.
903 }
904 }
905 }
906
907 if !pushed {
908 retain.push(predicate);
909 }
910 }
911
912 (retain, push_downs)
913 }
914
915 /// Computes "safe" predicates to push through a Map.
916 ///
917 /// In the case of a Filter { Map {...} }, we can always push down the Filter
918 /// by inlining expressions from the Map. We don't want to do this in general,
919 /// however, since general inlining can result in exponential blowup in the size
920 /// of expressions, so we only do this in the case where the size after inlining
921 /// is below a certain limit.
922 ///
923 /// Returns the predicates that can be pushed down, followed by ones that cannot.
924 pub fn push_filters_through_map(
925 map_exprs: &Vec<MirScalarExpr>,
926 predicates: &mut Vec<MirScalarExpr>,
927 input_arity: usize,
928 all_errors: bool,
929 ) -> Result<(Vec<MirScalarExpr>, Vec<MirScalarExpr>), TransformError> {
930 let mut pushdown = Vec::new();
931 let mut retained = Vec::new();
932 for predicate in predicates.drain(..) {
933 // We don't push down literal errors, unless all predicates are.
934 if !predicate.is_literal_err() || all_errors {
935 // Consider inlining Map expressions.
936 if let Some(cleaned) =
937 Self::inline_if_not_too_big(&predicate, input_arity, map_exprs)?
938 {
939 pushdown.push(cleaned);
940 } else {
941 retained.push(predicate);
942 }
943 } else {
944 retained.push(predicate);
945 }
946 }
947 Ok((retained, pushdown))
948 }
949
950 /// This fn should be called with a Filter `expr` that is after a Map. `input_arity` is the
951 /// arity of the input of the Map. This fn eliminates such column refs in `expr` that refer not
952 /// to a column in the input of the Map, but to a column that is created by the Map. It does
953 /// this by transitively inlining Map expressions until no such expression remains that points
954 /// to a Map expression. The return value is the cleaned up expression. The fn bails out with a
955 /// None if the resulting expression would be made too big by the inlinings.
956 ///
957 /// OOO: (Optimizer Optimization Opportunity) This function might do work proportional to the
958 /// total size of the Map expressions. We call this function for each predicate above the Map,
959 /// which will be kind of quadratic, i.e., if there are many predicates and a big Map, then this
960 /// will be slow. We could instead pass a vector of Map expressions and call this fn only once.
961 /// The only downside would be that then the inlining limit being hit in the middle part of this
962 /// function would prevent us from inlining any predicates, even ones that wouldn't hit the
963 /// inlining limit if considered on their own.
964 fn inline_if_not_too_big(
965 expr: &MirScalarExpr,
966 input_arity: usize,
967 map_exprs: &Vec<MirScalarExpr>,
968 ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
969 let size_limit = 1000;
970
971 // Transitively determine the support of `expr` produced by `map_exprs`
972 // that needs to be inlined.
973 let cols_to_inline = {
974 let mut support = BTreeSet::new();
975
976 // Seed with `map_exprs` support in `expr`.
977 expr.visit_pre(|e| {
978 if let MirScalarExpr::Column(c, _) = e {
979 if *c >= input_arity {
980 support.insert(*c);
981 }
982 }
983 });
984
985 // Compute transitive closure of supports in `map_exprs`.
986 let mut workset = support.iter().cloned().collect::<Vec<_>>();
987 let mut buffer = vec![];
988 while !workset.is_empty() {
989 // Swap the (empty) `drained` buffer with the `workset`.
990 std::mem::swap(&mut workset, &mut buffer);
991 // Drain the `buffer` and update `support` and `workset`.
992 for c in buffer.drain(..) {
993 map_exprs[c - input_arity].visit_pre(|e| {
994 if let MirScalarExpr::Column(c, _) = e {
995 if *c >= input_arity {
996 if support.insert(*c) {
997 workset.push(*c);
998 }
999 }
1000 }
1001 });
1002 }
1003 }
1004 support
1005 };
1006
1007 let mut inlined = BTreeMap::<usize, (MirScalarExpr, usize)>::new();
1008 // Populate the memo table in ascending column order (which respects the
1009 // dependency order of `map_exprs` references). Break early if memoization
1010 // fails for one of the columns in `cols_to_inline`.
1011 for c in cols_to_inline.iter() {
1012 let mut new_expr = map_exprs[*c - input_arity].clone();
1013 let mut new_size = 0;
1014 new_expr.visit_mut_post(&mut |expr| {
1015 new_size += 1;
1016 if let MirScalarExpr::Column(c, _) = expr {
1017 if *c >= input_arity && new_size <= size_limit {
1018 // (inlined[c] is safe, because we proceed in column order, and we break out
1019 // of the loop when we stop inserting into memo.)
1020 let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1021 *expr = m_expr.clone();
1022 new_size += m_size - 1; // Adjust for the +1 above.
1023 }
1024 }
1025 })?;
1026
1027 if new_size <= size_limit {
1028 inlined.insert(*c, (new_expr, new_size));
1029 } else {
1030 break;
1031 }
1032 }
1033
1034 // Try to resolve expr against the memo table.
1035 if inlined.len() < cols_to_inline.len() {
1036 Ok(None) // We couldn't memoize all map expressions within the given limit.
1037 } else {
1038 let mut new_expr = expr.clone();
1039 let mut new_size = 0;
1040 new_expr.visit_mut_post(&mut |expr| {
1041 new_size += 1;
1042 if let MirScalarExpr::Column(c, _) = expr {
1043 if *c >= input_arity && new_size <= size_limit {
1044 // (inlined[c] is safe because of the outer if condition.)
1045 let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1046 *expr = m_expr.clone();
1047 new_size += m_size - 1; // Adjust for the +1 above.
1048 }
1049 }
1050 })?;
1051
1052 soft_assert_eq_no_log!(new_size, new_expr.size());
1053 if new_size <= size_limit {
1054 Ok(Some(new_expr)) // We managed to stay within the limit.
1055 } else {
1056 Ok(None) // Limit exceeded.
1057 }
1058 }
1059 }
1060 // fn inline_if_not_too_big(
1061 // expr: &MirScalarExpr,
1062 // input_arity: usize,
1063 // map_exprs: &Vec<MirScalarExpr>,
1064 // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1065 // let size_limit = 1000;
1066 // // Memoize cleaned up versions of Map expressions. (Not necessarily all the Map expressions
1067 // // will be involved.)
1068 // let mut memo: BTreeMap<MirScalarExpr, MirScalarExpr> = BTreeMap::new();
1069 // fn rec(
1070 // expr: &MirScalarExpr,
1071 // input_arity: usize,
1072 // map_exprs: &Vec<MirScalarExpr>,
1073 // memo: &mut BTreeMap<MirScalarExpr, MirScalarExpr>,
1074 // size_limit: usize,
1075 // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1076 // // (We can't use Entry::or_insert_with, because the closure would need to be fallible.
1077 // // We also can't manually match on the result of memo.entry, because that holds a
1078 // // borrow of memo, but we need to pass memo to the recursive call in the middle.)
1079 // match memo.get(expr) {
1080 // Some(memoized_result) => Ok(Some(memoized_result.clone())),
1081 // None => {
1082 // let mut expr_size = expr.size()?;
1083 // let mut cleaned_expr = expr.clone();
1084 // let mut bail = false;
1085 // cleaned_expr.try_visit_mut_post(&mut |expr| {
1086 // Ok(if !bail {
1087 // match expr {
1088 // MirScalarExpr::Column(col) => {
1089 // if *col >= input_arity {
1090 // let to_inline = rec(
1091 // &map_exprs[*col - input_arity],
1092 // input_arity,
1093 // map_exprs,
1094 // memo,
1095 // size_limit,
1096 // )?;
1097 // if let Some(to_inline) = to_inline {
1098 // // The `-1` is because the expression that we are
1099 // // replacing has a size of 1.
1100 // expr_size += to_inline.size()? - 1;
1101 // *expr = to_inline;
1102 // if expr_size > size_limit {
1103 // bail = true;
1104 // }
1105 // } else {
1106 // bail = true;
1107 // }
1108 // }
1109 // }
1110 // _ => (),
1111 // }
1112 // })
1113 // })?;
1114 // soft_assert_eq!(cleaned_expr.size()?, expr_size);
1115 // if !bail {
1116 // memo.insert(expr.clone(), cleaned_expr.clone());
1117 // Ok(Some(cleaned_expr))
1118 // } else {
1119 // Ok(None)
1120 // }
1121 // }
1122 // }
1123 // }
1124 // rec(expr, input_arity, map_exprs, &mut memo, size_limit)
1125 // }
1126
1127 /// Computes "safe" predicate to push through a FlatMap.
1128 ///
1129 /// In the case of a Filter { FlatMap {...} }, we want to push through all predicates
1130 /// that (1) are not literal errors and (2) have support exclusively in the columns
1131 /// provided by the FlatMap input.
1132 ///
1133 /// Returns the predicates that can be pushed down, followed by ones that cannot.
1134 fn push_filters_through_flat_map(
1135 predicates: &mut Vec<MirScalarExpr>,
1136 input_arity: usize,
1137 ) -> (Vec<MirScalarExpr>, Vec<MirScalarExpr>) {
1138 let mut pushdown = Vec::new();
1139 let mut retained = Vec::new();
1140 for predicate in predicates.drain(..) {
1141 // First, check if we can push this predicate down. We can do so if and only if:
1142 // (1) the predicate is not a literal error, and
1143 // (2) each column it references is from the input.
1144 if (!predicate.is_literal_err()) && predicate.support().iter().all(|c| *c < input_arity)
1145 {
1146 pushdown.push(predicate);
1147 } else {
1148 retained.push(predicate);
1149 }
1150 }
1151 (retained, pushdown)
1152 }
1153
1154 /// If `s` is of the form
1155 /// `(isnull(expr1) && isnull(expr2)) || (expr1 = expr2)`, or
1156 /// `(decompose_is_null(expr1) && decompose_is_null(expr2)) || (expr1 = expr2)`,
1157 /// extract `expr1` and `expr2`.
1158 fn extract_equal_or_both_null(
1159 s: &mut MirScalarExpr,
1160 column_types: &[ReprColumnType],
1161 ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1162 if let MirScalarExpr::CallVariadic {
1163 func: VariadicFunc::Or(_),
1164 exprs,
1165 } = s
1166 {
1167 if let &[ref or_lhs, ref or_rhs] = &**exprs {
1168 // Check both orders of operands of the OR
1169 return Self::extract_equal_or_both_null_inner(or_lhs, or_rhs, column_types)
1170 .or_else(|| {
1171 Self::extract_equal_or_both_null_inner(or_rhs, or_lhs, column_types)
1172 });
1173 }
1174 }
1175 None
1176 }
1177
1178 fn extract_equal_or_both_null_inner(
1179 or_arg1: &MirScalarExpr,
1180 or_arg2: &MirScalarExpr,
1181 column_types: &[ReprColumnType],
1182 ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1183 use mz_expr::BinaryFunc;
1184 if let MirScalarExpr::CallBinary {
1185 func: BinaryFunc::Eq(_),
1186 expr1: eq_lhs,
1187 expr2: eq_rhs,
1188 } = &or_arg2
1189 {
1190 let isnull1 = eq_lhs.clone().call_is_null();
1191 let isnull2 = eq_rhs.clone().call_is_null();
1192 let both_null = MirScalarExpr::call_variadic(And, vec![isnull1, isnull2]);
1193
1194 if Self::extract_reduced_conjunction_terms(both_null, column_types)
1195 == Self::extract_reduced_conjunction_terms(or_arg1.clone(), column_types)
1196 {
1197 return Some(((**eq_lhs).clone(), (**eq_rhs).clone()));
1198 }
1199 }
1200 None
1201 }
1202
1203 /// Reduces the given expression and returns its AND-ed terms.
1204 fn extract_reduced_conjunction_terms(
1205 mut s: MirScalarExpr,
1206 column_types: &[ReprColumnType],
1207 ) -> Vec<MirScalarExpr> {
1208 s.reduce(column_types);
1209
1210 if let MirScalarExpr::CallVariadic {
1211 func: VariadicFunc::And(_),
1212 exprs,
1213 } = s
1214 {
1215 exprs
1216 } else {
1217 vec![s]
1218 }
1219 }
1220}