mz_transform/predicate_pushdown.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes predicates down through other operators.
11//!
12//! This action generally improves the quality of the query, in that selective per-record
13//! filters reduce the volume of data before they arrive at more expensive operators.
14//!
15//!
16//! The one time when this action might not improve the quality of a query is
17//! if a filter gets pushed down on an arrangement because that blocks arrangement
18//! reuse. It is assumed that actions that need an arrangement are responsible for
19//! lifting filters out of the way.
20//!
21//! Predicate pushdown will not push down literal errors, unless it is certain that
22//! the literal errors will be unconditionally evaluated. For example, the pushdown
23//! will not happen if not all predicates can be pushed down (e.g. reduce and map),
24//! or if we are not certain that the input is non-empty (e.g. join).
25//! Note that this is not addressing the problem in its full generality, because this problem can
26//! occur with any function call that might error (although much more rarely than with literal
27//! errors). See <https://github.com/MaterializeInc/database-issues/issues/4972#issuecomment-1547391011>
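//! For example (an illustrative sketch of the join case): a predicate that reduces to a
//! literal error, such as `1 / 0`, is not pushed below a `Join`. If another join input
//! turns out to be empty, the join output is empty and the error is never evaluated in
//! the original plan, whereas a pushed-down copy of the predicate would surface it.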
28//!
29//! ```rust
30//! use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr, func};
31//! use mz_ore::id_gen::IdGen;
32//! use mz_repr::{SqlColumnType, Datum, SqlRelationType, SqlScalarType};
33//! use mz_repr::optimize::OptimizerFeatures;
34//! use mz_transform::{reprtypecheck, typecheck, Transform, TransformCtx};
35//! use mz_transform::dataflow::DataflowMetainfo;
36//!
37//! use mz_transform::predicate_pushdown::PredicatePushdown;
38//!
39//! let input1 = MirRelationExpr::constant(vec![], SqlRelationType::new(vec![
40//! SqlScalarType::Bool.nullable(false),
41//! ]));
42//! let input2 = MirRelationExpr::constant(vec![], SqlRelationType::new(vec![
43//! SqlScalarType::Bool.nullable(false),
44//! ]));
45//! let input3 = MirRelationExpr::constant(vec![], SqlRelationType::new(vec![
46//! SqlScalarType::Bool.nullable(false),
47//! ]));
48//! let join = MirRelationExpr::join(
49//! vec![input1.clone(), input2.clone(), input3.clone()],
50//! vec![vec![(0, 0), (2, 0)].into_iter().collect()],
51//! );
52//!
53//! let predicate0 = MirScalarExpr::column(0);
54//! let predicate1 = MirScalarExpr::column(1);
55//! let predicate01 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(2), func::AddInt64);
56//! let predicate012 = MirScalarExpr::literal_false();
57//!
58//! let mut expr = join.filter(
59//! vec![
60//! predicate0.clone(),
61//! predicate1.clone(),
62//! predicate01.clone(),
63//! predicate012.clone(),
64//! ]);
65//!
66//! let features = OptimizerFeatures::default();
67//! let typecheck_ctx = typecheck::empty_context();
68//! let repr_typecheck_ctx = reprtypecheck::empty_context();
69//! let mut df_meta = DataflowMetainfo::default();
70//! let mut transform_ctx = TransformCtx::local(&features, &typecheck_ctx, &repr_typecheck_ctx, &mut df_meta, None, None);
71//!
72//! PredicatePushdown::default().transform(&mut expr, &mut transform_ctx);
73//!
74//! let predicate00 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(0), func::AddInt64);
75//! let expected_expr = MirRelationExpr::join(
76//! vec![
77//! input1.clone().filter(vec![predicate0.clone(), predicate00.clone()]),
78//! input2.clone().filter(vec![predicate0.clone()]),
79//! input3.clone().filter(vec![predicate0, predicate00])
80//! ],
81//! vec![vec![(0, 0), (2, 0)].into_iter().collect()],
82//! ).filter(vec![predicate012]);
83//! assert_eq!(expected_expr, expr)
84//! ```
85
86use std::collections::{BTreeMap, BTreeSet};
87
88use itertools::Itertools;
89use mz_expr::visit::{Visit, VisitChildren};
90use mz_expr::{
91 AggregateFunc, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
92 VariadicFunc, func,
93};
94use mz_ore::soft_assert_eq_no_log;
95use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
96use mz_repr::{Datum, SqlColumnType, SqlScalarType};
97
98use crate::{TransformCtx, TransformError};
99
100/// Pushes predicates down through other operators.
101#[derive(Debug)]
102pub struct PredicatePushdown {
103 recursion_guard: RecursionGuard,
104}
105
106impl Default for PredicatePushdown {
107 fn default() -> PredicatePushdown {
108 PredicatePushdown {
109 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
110 }
111 }
112}
113
114impl CheckedRecursion for PredicatePushdown {
115 fn recursion_guard(&self) -> &RecursionGuard {
116 &self.recursion_guard
117 }
118}
119
120impl crate::Transform for PredicatePushdown {
121 fn name(&self) -> &'static str {
122 "PredicatePushdown"
123 }
124
125 #[mz_ore::instrument(
126 target = "optimizer",
127 level = "debug",
128 fields(path.segment = "predicate_pushdown")
129 )]
130 fn actually_perform_transform(
131 &self,
132 relation: &mut MirRelationExpr,
133 _: &mut TransformCtx,
134 ) -> Result<(), TransformError> {
135 let mut empty = BTreeMap::new();
136 let result = self.action(relation, &mut empty);
137 mz_repr::explain::trace_plan(&*relation);
138 result
139 }
140}
141
142impl PredicatePushdown {
143 /// Predicate pushdown
144 ///
145 /// This method looks for opportunities to push predicates toward
146 /// sources of data. Primarily, this means finding `Filter` expressions
147 /// and moving their predicates through the operators beneath them.
148 ///
149 /// In addition, the method accumulates the intersection of predicates
150 /// applied to each `Get` expression, so that the predicate can
151 /// then be pushed through to a `Let` binding, or to the external
152 /// source of the data if the `Get` binds to another view.
153 pub fn action(
154 &self,
155 relation: &mut MirRelationExpr,
156 get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
157 ) -> Result<(), TransformError> {
158 self.checked_recur(|_| {
159            // In the cases of Filter, Get, Let/LetRec, and Join we have
160            // specific work to do; otherwise we should recursively descend.
161 match relation {
162 MirRelationExpr::Filter { input, predicates } => {
163 // Reduce the predicates to determine as best as possible
164 // whether they are literal errors before working with them.
165 let input_type = input.typ();
166 for predicate in predicates.iter_mut() {
167 predicate.reduce(&input_type.column_types);
168 }
169
170 // It can be helpful to know whether all predicates are literal errors;
171 // if some are not, that is justification for not pushing down the literal errors.
172 let all_errors = predicates.iter().all(|p| p.is_literal_err());
173 // Depending on the type of `input` we have different
174 // logic to apply to consider pushing `predicates` down.
175 match &mut **input {
176 MirRelationExpr::Let { body, .. }
177 | MirRelationExpr::LetRec { body, .. } => {
178 // Push all predicates to the body.
179 **body = body
180 .take_dangerous()
181 .filter(std::mem::replace(predicates, Vec::new()));
182
183 self.action(input, get_predicates)?;
184 }
185 MirRelationExpr::Get { id, .. } => {
186 // We can report the predicates upward in `get_predicates`,
187 // but we are not yet able to delete them from the
188 // `Filter`.
189 get_predicates
190 .entry(*id)
191 .or_insert_with(|| predicates.iter().cloned().collect())
192 .retain(|p| predicates.contains(p));
193 }
194 MirRelationExpr::Join {
195 inputs,
196 equivalences,
197 ..
198 } => {
199 // We want to scan `predicates` for any that can
200 // 1) become join variable constraints
201 // 2) apply to individual elements of `inputs`.
202 // Figuring out the set of predicates that belong to
203 // the latter group requires 1) knowing which predicates
204 // are in the former group and 2) that the variable
205 // constraints be in canonical form.
206 // Thus, there is a first scan across `predicates` to
207 // populate the join variable constraints
208 // and a second scan across the remaining predicates
209 // to see which ones can become individual elements of
210 // `inputs`.
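// Illustrative example: for `Filter { A.x = B.y AND A.x > 7 }` over
// `Join(A, B)`, the first scan turns `A.x = B.y` into a join
// equivalence class, and the second scan localizes `A.x > 7` so that
// it can be pushed down to the join inputs.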
211
212 let input_mapper = mz_expr::JoinInputMapper::new(inputs);
213
214 // Predicates not translated into join variable
215 // constraints. We will attempt to push them at all
216 // inputs, and retain those that cannot be pushed to any input.
217 let mut pred_not_translated = Vec::new();
218
219 for mut predicate in predicates.drain(..) {
220 use mz_expr::{BinaryFunc, UnaryFunc};
221 if let MirScalarExpr::CallBinary {
222 func: BinaryFunc::Eq(_),
223 expr1,
224 expr2,
225 } = &predicate
226 {
227 // Translate into join variable constraints:
228 // 1) `nonliteral1 == nonliteral2` constraints
229 // 2) `expr == literal` where `expr` refers to more
230 // than one input.
231 let input_count =
232 input_mapper.lookup_inputs(&predicate).count();
233 if (!expr1.is_literal() && !expr2.is_literal())
234 || input_count >= 2
235 {
236 // `col1 == col2` as a `MirScalarExpr`
237 // implies `!isnull(col1)` as well.
238 // `col1 == col2` as a join constraint does
239 // not have this extra implication.
240 // Thus, when translating the
241 // `MirScalarExpr` to a join constraint, we
242 // need to retain the `!isnull(col1)`
243 // information.
244 if expr1.typ(&input_type.column_types).nullable {
245 pred_not_translated.push(
246 expr1
247 .clone()
248 .call_unary(UnaryFunc::IsNull(func::IsNull))
249 .call_unary(UnaryFunc::Not(func::Not)),
250 );
251 } else if expr2.typ(&input_type.column_types).nullable {
252 pred_not_translated.push(
253 expr2
254 .clone()
255 .call_unary(UnaryFunc::IsNull(func::IsNull))
256 .call_unary(UnaryFunc::Not(func::Not)),
257 );
258 }
259 equivalences
260 .push(vec![(**expr1).clone(), (**expr2).clone()]);
261 continue;
262 }
263 } else if let Some((expr1, expr2)) =
264 Self::extract_equal_or_both_null(
265 &mut predicate,
266 &input_type.column_types,
267 )
268 {
269 // Also translate into join variable constraints:
270 // 3) `((nonliteral1 = nonliteral2) || (nonliteral1
271 // is null && nonliteral2 is null))`
272 equivalences.push(vec![expr1, expr2]);
273 continue;
274 }
275 pred_not_translated.push(predicate)
276 }
277
278 mz_expr::canonicalize::canonicalize_equivalences(
279 equivalences,
280 std::iter::once(&input_type.column_types),
281 );
282
283 let (retain, push_downs) = Self::push_filters_through_join(
284 &input_mapper,
285 equivalences,
286 pred_not_translated,
287 );
288
289 Self::update_join_inputs_with_push_downs(inputs, push_downs);
290
291 // Recursively descend on the join
292 self.action(input, get_predicates)?;
293
294 // remove all predicates that were pushed down from the current Filter node
295 *predicates = retain;
296 }
297 MirRelationExpr::Reduce {
298 input: inner,
299 group_key,
300 aggregates,
301 monotonic: _,
302 expected_group_size: _,
303 } => {
304 let mut retain = Vec::new();
305 let mut push_down = Vec::new();
306 for predicate in predicates.drain(..) {
307 // Do not push down literal errors unless all predicates are literal errors.
308 if !predicate.is_literal_err() || all_errors {
309 let mut supported = true;
310 let mut new_predicate = predicate.clone();
311 new_predicate.visit_pre(|e| {
312 if let MirScalarExpr::Column(c, _) = e {
313 if *c >= group_key.len() {
314 supported = false;
315 }
316 }
317 });
318 if supported {
319 new_predicate.visit_mut_post(&mut |e| {
320 if let MirScalarExpr::Column(i, _) = e {
321 *e = group_key[*i].clone();
322 }
323 })?;
324 push_down.push(new_predicate);
325 } else if let MirScalarExpr::Column(col, _) = &predicate {
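// Special case: filtering on the output column of a lone
// `ANY` aggregate is the same as filtering on the aggregate's
// input expression, so push that expression down as a
// predicate and aggregate over `true` instead.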
326 if *col == group_key.len()
327 && aggregates.len() == 1
328 && aggregates[0].func == AggregateFunc::Any
329 {
330 push_down.push(aggregates[0].expr.clone());
331 aggregates[0].expr = MirScalarExpr::literal_ok(
332 Datum::True,
333 SqlScalarType::Bool,
334 );
335 } else {
336 retain.push(predicate);
337 }
338 } else {
339 retain.push(predicate);
340 }
341 } else {
342 retain.push(predicate);
343 }
344 }
345
346 if !push_down.is_empty() {
347 *inner = Box::new(inner.take_dangerous().filter(push_down));
348 }
349 self.action(inner, get_predicates)?;
350
351 // remove all predicates that were pushed down from the current Filter node
352 std::mem::swap(&mut retain, predicates);
353 }
354 MirRelationExpr::TopK {
355 input,
356 group_key,
357 order_key: _,
358 limit,
359 offset: _,
360 monotonic: _,
361 expected_group_size: _,
362 } => {
363 let mut retain = Vec::new();
364 let mut push_down = Vec::new();
365
366 let mut support = BTreeSet::new();
367 support.extend(group_key.iter().cloned());
368 if let Some(limit) = limit {
369 // Strictly speaking not needed because the
370 // `limit` support should be a subset of the
371 // `group_key` support, but we don't want to
372 // take this for granted here.
373 limit.support_into(&mut support);
374 }
375
376 for predicate in predicates.drain(..) {
377 // Do not push down literal errors unless all predicates are literal errors.
378 if (!predicate.is_literal_err() || all_errors)
379 && predicate.support().is_subset(&support)
380 {
381 push_down.push(predicate);
382 } else {
383 retain.push(predicate);
384 }
385 }
386
387 // remove all predicates that were pushed down from the current Filter node
388 std::mem::swap(&mut retain, predicates);
389
390 if !push_down.is_empty() {
391 *input = Box::new(input.take_dangerous().filter(push_down));
392 }
393
394 self.action(input, get_predicates)?;
395 }
396 MirRelationExpr::Threshold { input } => {
397 let predicates = std::mem::take(predicates);
398 *relation = input.take_dangerous().filter(predicates).threshold();
399 self.action(relation, get_predicates)?;
400 }
401 MirRelationExpr::Project { input, outputs } => {
402 let predicates = predicates.drain(..).map(|mut predicate| {
403 predicate.permute(outputs);
404 predicate
405 });
406 *relation = input
407 .take_dangerous()
408 .filter(predicates)
409 .project(outputs.clone());
410
411 self.action(relation, get_predicates)?;
412 }
413 MirRelationExpr::Filter {
414 input,
415 predicates: predicates2,
416 } => {
417 *relation = input
418 .take_dangerous()
419 .filter(predicates.clone().into_iter().chain(predicates2.clone()));
420 self.action(relation, get_predicates)?;
421 }
422 MirRelationExpr::Map { input, scalars } => {
423 let (retained, pushdown) = Self::push_filters_through_map(
424 scalars,
425 predicates,
426 input.arity(),
427 all_errors,
428 )?;
429 let scalars = std::mem::take(scalars);
430 let mut result = input.take_dangerous();
431 if !pushdown.is_empty() {
432 result = result.filter(pushdown);
433 }
434 self.action(&mut result, get_predicates)?;
435 result = result.map(scalars);
436 if !retained.is_empty() {
437 result = result.filter(retained);
438 }
439 *relation = result;
440 }
441 MirRelationExpr::FlatMap { input, .. } => {
442 let (mut retained, pushdown) =
443 Self::push_filters_through_flat_map(predicates, input.arity());
444
445 // remove all predicates that were pushed down from the current Filter node
446 std::mem::swap(&mut retained, predicates);
447
448 if !pushdown.is_empty() {
449 // put the filter on top of the input
450 **input = input.take_dangerous().filter(pushdown);
451 }
452
453 // ... and keep pushing predicates down
454 self.action(input, get_predicates)?;
455 }
456 MirRelationExpr::Union { base, inputs } => {
457 let predicates = std::mem::take(predicates);
458 *base = Box::new(base.take_dangerous().filter(predicates.clone()));
459 self.action(base, get_predicates)?;
460 for input in inputs {
461 *input = input.take_dangerous().filter(predicates.clone());
462 self.action(input, get_predicates)?;
463 }
464 }
465 MirRelationExpr::Negate { input } => {
466 // Don't push literal errors past a Negate. The problem is that it's
467 // hard to appropriately reflect the negation in the error stream:
468 // - If we don't negate, then errors that should cancel out will not
469 // cancel out. For example, see
470 // https://github.com/MaterializeInc/database-issues/issues/5691
471 // - If we negate, then unrelated errors might cancel out. E.g., there
472 // might be a division-by-0 in both inputs to an EXCEPT ALL, but
473 // on different input data. These shouldn't cancel out.
474 let (retained, pushdown): (Vec<_>, Vec<_>) = std::mem::take(predicates)
475 .into_iter()
476 .partition(|p| p.is_literal_err());
477 let mut result = input.take_dangerous();
478 if !pushdown.is_empty() {
479 result = result.filter(pushdown);
480 }
481 self.action(&mut result, get_predicates)?;
482 result = result.negate();
483 if !retained.is_empty() {
484 result = result.filter(retained);
485 }
486 *relation = result;
487 }
488 x => {
489 x.try_visit_mut_children(|e| self.action(e, get_predicates))?;
490 }
491 }
492
493 // remove empty filters (junk by-product of the actual transform)
494 match relation {
495 MirRelationExpr::Filter { predicates, input } if predicates.is_empty() => {
496 *relation = input.take_dangerous();
497 }
498 _ => {}
499 }
500
501 Ok(())
502 }
503 MirRelationExpr::Get { id, .. } => {
504 // Purge all predicates associated with the id.
505 get_predicates
506 .entry(*id)
507 .or_insert_with(BTreeSet::new)
508 .clear();
509
510 Ok(())
511 }
512 MirRelationExpr::Let { id, body, value } => {
513 // Push predicates and collect intersection at `Get`s.
514 self.action(body, get_predicates)?;
515
516 // `get_predicates` should now contain the intersection
517 // of predicates at each *use* of the binding. If it is
518 // non-empty, we can move those predicates to the value.
519 Self::push_into_let_binding(get_predicates, id, value, &mut [body]);
520
521 // Continue recursively on the value.
522 self.action(value, get_predicates)
523 }
524 MirRelationExpr::LetRec {
525 ids,
526 values,
527 limits: _,
528 body,
529 } => {
530 // Note: This could be extended to be able to do a little more pushdowns, see
531 // https://github.com/MaterializeInc/database-issues/issues/5336#issuecomment-1477588262
532
533 // Pre-compute which Ids are used across iterations
534 let ids_used_across_iterations = MirRelationExpr::recursive_ids(ids, values);
535
536 // Predicate pushdown within the body
537 self.action(body, get_predicates)?;
538
539 // `users` will be the body plus the values of those bindings that we have seen
540 // so far, while going one-by-one through the list of bindings backwards.
541 // `users` contains those expressions from which we harvested `get_predicates`,
542 // and therefore we should attend to all of these expressions when pushing down
543 // a predicate into a Let binding.
544 let mut users = vec![&mut **body];
545 for (id, value) in ids.iter_mut().rev().zip_eq(values.into_iter().rev()) {
546 // Predicate pushdown from Gets in `users` into the value of a Let binding
547 //
548 // For now, we simply always avoid pushing into a Let binding that is
549 // referenced across iterations to avoid soundness problems and infinite
550 // pushdowns.
551 //
552 // Note that `push_into_let_binding` makes a further check based on
553 // `get_predicates`: We push a predicate into the value of a binding, only
554 // if all Gets of this Id have this same predicate on top of them.
555 if !ids_used_across_iterations.contains(id) {
556 Self::push_into_let_binding(get_predicates, id, value, &mut users);
557 }
558
559 // Predicate pushdown within a binding
560 self.action(value, get_predicates)?;
561
562 users.push(value);
563 }
564
565 Ok(())
566 }
567 MirRelationExpr::Join {
568 inputs,
569 equivalences,
570 ..
571 } => {
572 // The goal is to push
573 // 1) equivalences of the form `expr = <runtime constant>`, where `expr`
574 // comes from a single input.
575 // 2) equivalences of the form `expr1 = expr2`, where both
576 // expressions come from the same single input.
577 let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
578 mz_expr::canonicalize::canonicalize_equivalences(
579 equivalences,
580 input_types.iter().map(|t| &t.column_types),
581 );
582
583 let input_mapper = mz_expr::JoinInputMapper::new_from_input_types(&input_types);
584 // Predicates to push at each input, and to lift out the join.
585 let mut push_downs = vec![Vec::new(); inputs.len()];
586
587 for equivalence_pos in 0..equivalences.len() {
588 // Case 1: there is more than one literal in the
589 // equivalence class. Because the equivalences have been
590 // deduplicated, this means that everything in the equivalence
591 // class must be equal to two different literals, so the
592 // entire relation zeroes out.
593 if equivalences[equivalence_pos]
594 .iter()
595 .filter(|expr| expr.is_literal())
596 .count()
597 > 1
598 {
599 relation.take_safely(Some(relation.typ_with_input_types(&input_types)));
600 return Ok(());
601 }
602
603 let runtime_constants = equivalences[equivalence_pos]
604 .iter()
605 .filter(|expr| expr.support().is_empty())
606 .cloned()
607 .collect::<Vec<_>>();
608 if !runtime_constants.is_empty() {
609 // Case 2: There is at least one runtime constant in the equivalence class.
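// Illustrative example: for the equivalence class `{A.x, B.y, 5}`,
// we can push `A.x = 5` down to input `A` and `B.y = 5` down to
// input `B`, and remove `A.x` and `B.y` from the class.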
610 let gen_literal_equality_preds = |expr: MirScalarExpr| {
611 let mut equality_preds = Vec::new();
612 for constant in runtime_constants.iter() {
613 let pred = if constant.is_literal_null() {
614 MirScalarExpr::CallUnary {
615 func: mz_expr::UnaryFunc::IsNull(func::IsNull),
616 expr: Box::new(expr.clone()),
617 }
618 } else {
619 MirScalarExpr::CallBinary {
620 func: mz_expr::func::Eq.into(),
621 expr1: Box::new(expr.clone()),
622 expr2: Box::new(constant.clone()),
623 }
624 };
625 equality_preds.push(pred);
626 }
627 equality_preds
628 };
629
630 // Find all single input expressions in the equivalence
631 // class and collect (position within the equivalence class,
632 // input the expression belongs to, localized version of the
633 // expression).
634 let mut single_input_exprs = equivalences[equivalence_pos]
635 .iter()
636 .enumerate()
637 .filter_map(|(pos, e)| {
638 let mut inputs = input_mapper.lookup_inputs(e);
639 if let Some(input) = inputs.next() {
640 if inputs.next().is_none() {
641 return Some((
642 pos,
643 input,
644 input_mapper.map_expr_to_local(e.clone()),
645 ));
646 }
647 }
648 None
649 })
650 .collect::<Vec<_>>();
651
652 // For every single-input expression `expr`, we can push
653 // down `expr = <runtime constant>` and remove `expr` from the
654 // equivalence class.
655 for (expr_pos, input, expr) in single_input_exprs.drain(..).rev() {
656 push_downs[input].extend(gen_literal_equality_preds(expr));
657 equivalences[equivalence_pos].remove(expr_pos);
658 }
659
660 // If none of the remaining expressions in the equivalence depend on
661 // input columns, and equality predicates with them have been pushed
662 // down, we can safely remove them from the equivalence.
663 // TODO: we could probably push equality predicates among the
664 // remaining constants to all join inputs to prevent any computation
665 // from happening until the condition is satisfied.
666 if equivalences[equivalence_pos]
667 .iter()
668 .all(|e| e.support().is_empty())
669 && push_downs.iter().any(|p| !p.is_empty())
670 {
671 equivalences[equivalence_pos].clear();
672 }
673 } else {
674 // Case 3: There are no constants in the equivalence
675 // class. Push a predicate for every pair of expressions
676 // in the equivalence that either belong to a single
677 // input or can be localized to a given input through
678 // the rest of the equivalences.
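// Illustrative example: for the equivalence class `{A.x, A.y, B.z}`,
// we can push `(A.x = A.y) || (A.x IS NULL && A.y IS NULL)` down to
// input `A`, keeping one of the expressions in the class so that the
// join constraint with `B.z` is not lost.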
679 let mut to_remove = Vec::new();
680 for input in 0..inputs.len() {
681 // Vector of pairs (position within the equivalence, localized
682 // expression). The position is None for expressions derived through
683 // other equivalences.
684 let localized = equivalences[equivalence_pos]
685 .iter()
686 .enumerate()
687 .filter_map(|(pos, expr)| {
688 if let MirScalarExpr::Column(col_pos, _) = &expr {
689 let local_col =
690 input_mapper.map_column_to_local(*col_pos);
691 if input == local_col.1 {
692 // TODO(mgree) !!! is it safe to propagate the name here?
693 return Some((
694 Some(pos),
695 MirScalarExpr::column(local_col.0),
696 ));
697 } else {
698 return None;
699 }
700 }
701 let mut inputs = input_mapper.lookup_inputs(expr);
702 if let Some(single_input) = inputs.next() {
703 if input == single_input && inputs.next().is_none() {
704 return Some((
705 Some(pos),
706 input_mapper.map_expr_to_local(expr.clone()),
707 ));
708 }
709 }
710 // Equivalences not including the current expression
711 let mut other_equivalences = equivalences.clone();
712 other_equivalences[equivalence_pos].remove(pos);
713 let mut localized = expr.clone();
714 if input_mapper.try_localize_to_input_with_bound_expr(
715 &mut localized,
716 input,
717 &other_equivalences[..],
718 ) {
719 Some((None, localized))
720 } else {
721 None
722 }
723 })
724 .collect::<Vec<_>>();
725
726 // If there are at least 2 expressions in the equivalence that
727 // can be localized to the same input, push all combinations
728 // of them to the input.
729 if localized.len() > 1 {
730 for mut pair in
731 localized.iter().map(|(_, expr)| expr).combinations(2)
732 {
733 let expr1 = pair.pop().unwrap();
734 let expr2 = pair.pop().unwrap();
735
736 push_downs[input].push(
737 MirScalarExpr::CallBinary {
738 func: func::Eq.into(),
739 expr1: Box::new(expr2.clone()),
740 expr2: Box::new(expr1.clone()),
741 }
742 .or(expr2
743 .clone()
744 .call_is_null()
745 .and(expr1.clone().call_is_null())),
746 );
747 }
748
749 if localized.len() == equivalences[equivalence_pos].len() {
750 // The equivalence is either a single input one or fully localizable
751 // to a single input through other equivalences, so it can be removed
752 // completely without introducing any new cross join.
753 to_remove.extend(0..equivalences[equivalence_pos].len());
754 } else {
755 // Leave an expression from this input in the equivalence to avoid
756 // cross joins
757 to_remove.extend(
758 localized.iter().filter_map(|(pos, _)| *pos).skip(1),
759 );
760 }
761 }
762 }
763
764 // Remove expressions that were pushed down to at least one input
765 to_remove.sort();
766 to_remove.dedup();
767 for pos in to_remove.iter().rev() {
768 equivalences[equivalence_pos].remove(*pos);
769 }
770 };
771 }
772
773 mz_expr::canonicalize::canonicalize_equivalences(
774 equivalences,
775 input_types.iter().map(|t| &t.column_types),
776 );
777
778 Self::update_join_inputs_with_push_downs(inputs, push_downs);
779
780 // Recursively descend on each of the inputs.
781 for input in inputs.iter_mut() {
782 self.action(input, get_predicates)?;
783 }
784
785 Ok(())
786 }
787 x => {
788 // Recursively descend.
789 x.try_visit_mut_children(|e| self.action(e, get_predicates))
790 }
791 }
792 })
793 }
794
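/// Wraps each join input in a `Filter` containing the predicates pushed
/// down to it, leaving inputs with no pushed-down predicates unchanged.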
795 fn update_join_inputs_with_push_downs(
796 inputs: &mut Vec<MirRelationExpr>,
797 push_downs: Vec<Vec<MirScalarExpr>>,
798 ) {
799 let new_inputs = inputs
800 .drain(..)
801 .zip_eq(push_downs)
802 .map(|(input, push_down)| {
803 if !push_down.is_empty() {
804 input.filter(push_down)
805 } else {
806 input
807 }
808 })
809 .collect();
810 *inputs = new_inputs;
811 }
812
813 // Checks `get_predicates` to see whether we can push a predicate into the Let binding given
814 // by `id` and `value`.
815 // `users` is the list of those expressions from which we will need to remove a predicate that
816 // is being pushed.
817 fn push_into_let_binding(
818 get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
819 id: &LocalId,
820 value: &mut MirRelationExpr,
821 users: &mut [&mut MirRelationExpr],
822 ) {
823 if let Some(list) = get_predicates.remove(&Id::Local(*id)) {
824 if !list.is_empty() {
825 // Remove the predicates in `list` from the users.
826 for user in users {
827 user.visit_pre_mut(|e| {
828 if let MirRelationExpr::Filter { input, predicates } = e {
829 if let MirRelationExpr::Get { id: get_id, .. } = **input {
830 if get_id == Id::Local(*id) {
831 predicates.retain(|p| !list.contains(p));
832 }
833 }
834 }
835 });
836 }
837 // Apply the predicates in `list` to value. Canonicalize
838 // `list` so that plans are always deterministic.
839 let mut list = list.into_iter().collect::<Vec<_>>();
840 mz_expr::canonicalize::canonicalize_predicates(
841 &mut list,
842 &value.typ().column_types,
843 );
844 *value = value.take_dangerous().filter(list);
845 }
846 }
847 }
848
849 /// Returns `(<predicates to retain>, <predicates to push at each input>)`.
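///
/// For example (illustrative): given the equivalence `{A.x, B.y}`, a predicate
/// `A.x > 7` can be localized both to input `A` (as is) and, via the
/// equivalence, to input `B` as `B.y > 7`, so it is pushed to both inputs and
/// not retained.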
850 pub fn push_filters_through_join(
851 input_mapper: &JoinInputMapper,
852 equivalences: &Vec<Vec<MirScalarExpr>>,
853 mut predicates: Vec<MirScalarExpr>,
854 ) -> (Vec<MirScalarExpr>, Vec<Vec<MirScalarExpr>>) {
855 let mut push_downs = vec![Vec::new(); input_mapper.total_inputs()];
856 let mut retain = Vec::new();
857
858 for predicate in predicates.drain(..) {
859 // Track if the predicate has been pushed to at least one input.
860 let mut pushed = false;
861 // For each input, try and see if the join
862 // equivalences allow the predicate to be rewritten
863 // in terms of only columns from that input.
864 for (index, push_down) in push_downs.iter_mut().enumerate() {
865 #[allow(deprecated)] // TODO: use `might_error` if possible.
866 if predicate.is_literal_err() || predicate.contains_error_if_null() {
867 // Do nothing. We don't push down literal errors,
868 // as we can't know the join will be non-empty.
869 //
870 // We also don't want to push anything that involves `error_if_null`. This is
871 // for the same reason why in theory we shouldn't really push anything that can
872 // error, assuming that we want to preserve error semantics. (Because we would
873 // create a spurious error if some other Join input ends up empty.) We can't fix
874 // this problem in general (as we can't just not push anything that might
875 // error), but we decided to fix the specific problem instance involving
876 // `error_if_null`, because it was very painful:
877 // <https://github.com/MaterializeInc/database-issues/issues/6258>
878 } else {
879 let mut localized = predicate.clone();
880 if input_mapper.try_localize_to_input_with_bound_expr(
881 &mut localized,
882 index,
883 equivalences,
884 ) {
885 push_down.push(localized);
886 pushed = true;
887 } else if let Some(consequence) = input_mapper
888 // (`consequence_for_input` assumes that
889 // `try_localize_to_input_with_bound_expr` has already
890 // been called on `localized`.)
891 .consequence_for_input(&localized, index)
892 {
893 push_down.push(consequence);
894 // We don't set `pushed` here! We want to retain the
895 // predicate, because we only pushed a consequence of
896 // it, but not the full predicate.
897 }
898 }
899 }
900
901 if !pushed {
902 retain.push(predicate);
903 }
904 }
905
906 (retain, push_downs)
907 }
908
909 /// Computes "safe" predicates to push through a Map.
910 ///
911 /// In the case of a Filter { Map {...} }, we can always push down the Filter
912 /// by inlining expressions from the Map. We don't want to do this in general,
913 /// however, since general inlining can result in exponential blowup in the size
914 /// of expressions, so we only do this in the case where the size after inlining
915 /// is below a certain limit.
916 ///
917 /// Returns the predicates to retain, followed by the predicates that can be pushed down.
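///
/// For example (illustrative): with `input_arity = 2` and a Map that appends
/// `#0 + #1` as column `#2`, the predicate `#2 > 7` is inlined to
/// `(#0 + #1) > 7` and pushed below the Map, while a predicate whose inlined
/// form would exceed the size limit is retained.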
918 pub fn push_filters_through_map(
919 map_exprs: &Vec<MirScalarExpr>,
920 predicates: &mut Vec<MirScalarExpr>,
921 input_arity: usize,
922 all_errors: bool,
923 ) -> Result<(Vec<MirScalarExpr>, Vec<MirScalarExpr>), TransformError> {
924 let mut pushdown = Vec::new();
925 let mut retained = Vec::new();
926 for predicate in predicates.drain(..) {
927 // We don't push down literal errors, unless all predicates are.
928 if !predicate.is_literal_err() || all_errors {
929 // Consider inlining Map expressions.
930 if let Some(cleaned) =
931 Self::inline_if_not_too_big(&predicate, input_arity, map_exprs)?
932 {
933 pushdown.push(cleaned);
934 } else {
935 retained.push(predicate);
936 }
937 } else {
938 retained.push(predicate);
939 }
940 }
941 Ok((retained, pushdown))
942 }
943
944 /// This fn should be called with a Filter `expr` that is after a Map. `input_arity` is the
945 /// arity of the input of the Map. This fn eliminates such column refs in `expr` that refer not
946 /// to a column in the input of the Map, but to a column that is created by the Map. It does
947 /// this by transitively inlining Map expressions until no such expression remains that points
948 /// to a Map expression. The return value is the cleaned up expression. The fn bails out with a
949 /// None if the resulting expression would be made too big by the inlinings.
950 ///
951 /// OOO: (Optimizer Optimization Opportunity) This function might do work proportional to the
952 /// total size of the Map expressions. We call this function for each predicate above the Map,
953 /// which will be kind of quadratic, i.e., if there are many predicates and a big Map, then this
954 /// will be slow. We could instead pass a vector of Map expressions and call this fn only once.
955 /// The only downside would be that then the inlining limit being hit in the middle part of this
956 /// function would prevent us from inlining any predicates, even ones that wouldn't hit the
957 /// inlining limit if considered on their own.
958 fn inline_if_not_too_big(
959 expr: &MirScalarExpr,
960 input_arity: usize,
961 map_exprs: &Vec<MirScalarExpr>,
962 ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
963 let size_limit = 1000;
964
965 // Transitively determine the support of `expr` produced by `map_exprs`
966 // that needs to be inlined.
967 let cols_to_inline = {
968 let mut support = BTreeSet::new();
969
970 // Seed with `map_exprs` support in `expr`.
971 expr.visit_pre(|e| {
972 if let MirScalarExpr::Column(c, _) = e {
973 if *c >= input_arity {
974 support.insert(*c);
975 }
976 }
977 });
978
979 // Compute transitive closure of supports in `map_exprs`.
980 let mut workset = support.iter().cloned().collect::<Vec<_>>();
981 let mut buffer = vec![];
982 while !workset.is_empty() {
983 // Swap the (empty, already drained) `buffer` with the `workset`.
984 std::mem::swap(&mut workset, &mut buffer);
985 // Drain the `buffer` and update `support` and `workset`.
986 for c in buffer.drain(..) {
987 map_exprs[c - input_arity].visit_pre(|e| {
988 if let MirScalarExpr::Column(c, _) = e {
989 if *c >= input_arity {
990 if support.insert(*c) {
991 workset.push(*c);
992 }
993 }
994 }
995 });
996 }
997 }
998 support
999 };
1000
1001 let mut inlined = BTreeMap::<usize, (MirScalarExpr, usize)>::new();
1002 // Populate the memo table in ascending column order (which respects the
1003 // dependency order of `map_exprs` references). Break early if memoization
1004 // fails for one of the columns in `cols_to_inline`.
1005 for c in cols_to_inline.iter() {
1006 let mut new_expr = map_exprs[*c - input_arity].clone();
1007 let mut new_size = 0;
1008 new_expr.visit_mut_post(&mut |expr| {
1009 new_size += 1;
1010 if let MirScalarExpr::Column(c, _) = expr {
1011 if *c >= input_arity && new_size <= size_limit {
1012 // (inlined[c] is safe, because we proceed in column order, and we break out
1013 // of the loop when we stop inserting into memo.)
1014 let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1015 *expr = m_expr.clone();
1016 new_size += m_size - 1; // Adjust for the +1 above.
1017 }
1018 }
1019 })?;
1020
1021 if new_size <= size_limit {
1022 inlined.insert(*c, (new_expr, new_size));
1023 } else {
1024 break;
1025 }
1026 }
1027
1028 // Try to resolve expr against the memo table.
1029 if inlined.len() < cols_to_inline.len() {
1030 Ok(None) // We couldn't memoize all map expressions within the given limit.
1031 } else {
1032 let mut new_expr = expr.clone();
1033 let mut new_size = 0;
1034 new_expr.visit_mut_post(&mut |expr| {
1035 new_size += 1;
1036 if let MirScalarExpr::Column(c, _) = expr {
1037 if *c >= input_arity && new_size <= size_limit {
1038 // (inlined[c] is safe because of the outer if condition.)
1039 let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1040 *expr = m_expr.clone();
1041 new_size += m_size - 1; // Adjust for the +1 above.
1042 }
1043 }
1044 })?;
1045
1046 soft_assert_eq_no_log!(new_size, new_expr.size());
1047 if new_size <= size_limit {
1048 Ok(Some(new_expr)) // We managed to stay within the limit.
1049 } else {
1050 Ok(None) // Limit exceeded.
1051 }
1052 }
1053 }
1054 // fn inline_if_not_too_big(
1055 // expr: &MirScalarExpr,
1056 // input_arity: usize,
1057 // map_exprs: &Vec<MirScalarExpr>,
1058 // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1059 // let size_limit = 1000;
1060 // // Memoize cleaned up versions of Map expressions. (Not necessarily all the Map expressions
1061 // // will be involved.)
1062 // let mut memo: BTreeMap<MirScalarExpr, MirScalarExpr> = BTreeMap::new();
1063 // fn rec(
1064 // expr: &MirScalarExpr,
1065 // input_arity: usize,
1066 // map_exprs: &Vec<MirScalarExpr>,
1067 // memo: &mut BTreeMap<MirScalarExpr, MirScalarExpr>,
1068 // size_limit: usize,
1069 // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1070 // // (We can't use Entry::or_insert_with, because the closure would need to be fallible.
1071 // // We also can't manually match on the result of memo.entry, because that holds a
1072 // // borrow of memo, but we need to pass memo to the recursive call in the middle.)
1073 // match memo.get(expr) {
1074 // Some(memoized_result) => Ok(Some(memoized_result.clone())),
1075 // None => {
1076 // let mut expr_size = expr.size()?;
1077 // let mut cleaned_expr = expr.clone();
1078 // let mut bail = false;
1079 // cleaned_expr.try_visit_mut_post(&mut |expr| {
1080 // Ok(if !bail {
1081 // match expr {
1082 // MirScalarExpr::Column(col) => {
1083 // if *col >= input_arity {
1084 // let to_inline = rec(
1085 // &map_exprs[*col - input_arity],
1086 // input_arity,
1087 // map_exprs,
1088 // memo,
1089 // size_limit,
1090 // )?;
1091 // if let Some(to_inline) = to_inline {
1092 // // The `-1` is because the expression that we are
1093 // // replacing has a size of 1.
1094 // expr_size += to_inline.size()? - 1;
1095 // *expr = to_inline;
1096 // if expr_size > size_limit {
1097 // bail = true;
1098 // }
1099 // } else {
1100 // bail = true;
1101 // }
1102 // }
1103 // }
1104 // _ => (),
1105 // }
1106 // })
1107 // })?;
1108 // soft_assert_eq!(cleaned_expr.size()?, expr_size);
1109 // if !bail {
1110 // memo.insert(expr.clone(), cleaned_expr.clone());
1111 // Ok(Some(cleaned_expr))
1112 // } else {
1113 // Ok(None)
1114 // }
1115 // }
1116 // }
1117 // }
1118 // rec(expr, input_arity, map_exprs, &mut memo, size_limit)
1119 // }
1120
1121 /// Computes "safe" predicates to push through a FlatMap.
1122 ///
1123 /// In the case of a Filter { FlatMap {...} }, we want to push through all predicates
1124 /// that (1) are not literal errors and (2) have support exclusively in the columns
1125 /// provided by the FlatMap input.
1126 ///
1127 /// Returns the predicates to retain, followed by the predicates that can be pushed down.
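///
/// For example (illustrative): with `input_arity = 2`, the predicate `#1 > 7`
/// is pushed below the FlatMap, while `#2 > 7`, which references a column
/// produced by the FlatMap, is retained.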
1128 fn push_filters_through_flat_map(
1129 predicates: &mut Vec<MirScalarExpr>,
1130 input_arity: usize,
1131 ) -> (Vec<MirScalarExpr>, Vec<MirScalarExpr>) {
1132 let mut pushdown = Vec::new();
1133 let mut retained = Vec::new();
1134 for predicate in predicates.drain(..) {
1135 // First, check if we can push this predicate down. We can do so if and only if:
1136 // (1) the predicate is not a literal error, and
1137 // (2) each column it references is from the input.
1138 if (!predicate.is_literal_err()) && predicate.support().iter().all(|c| *c < input_arity)
1139 {
1140 pushdown.push(predicate);
1141 } else {
1142 retained.push(predicate);
1143 }
1144 }
1145 (retained, pushdown)
1146 }
1147
1148 /// If `s` is of the form
1149 /// `(isnull(expr1) && isnull(expr2)) || (expr1 = expr2)`, or
1150 /// `(decompose_is_null(expr1) && decompose_is_null(expr2)) || (expr1 = expr2)`,
1151 /// extract `expr1` and `expr2`.
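///
/// For example, `(a = b) || (isnull(a) && isnull(b))` yields `Some((a, b))`.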
1152 fn extract_equal_or_both_null(
1153 s: &mut MirScalarExpr,
1154 column_types: &[SqlColumnType],
1155 ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1156 if let MirScalarExpr::CallVariadic {
1157 func: VariadicFunc::Or,
1158 exprs,
1159 } = s
1160 {
1161 if let &[ref or_lhs, ref or_rhs] = &**exprs {
1162 // Check both orders of operands of the OR
1163 return Self::extract_equal_or_both_null_inner(or_lhs, or_rhs, column_types)
1164 .or_else(|| {
1165 Self::extract_equal_or_both_null_inner(or_rhs, or_lhs, column_types)
1166 });
1167 }
1168 }
1169 None
1170 }
1171
1172 fn extract_equal_or_both_null_inner(
1173 or_arg1: &MirScalarExpr,
1174 or_arg2: &MirScalarExpr,
1175 column_types: &[SqlColumnType],
1176 ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1177 use mz_expr::BinaryFunc;
1178 if let MirScalarExpr::CallBinary {
1179 func: BinaryFunc::Eq(_),
1180 expr1: eq_lhs,
1181 expr2: eq_rhs,
1182 } = &or_arg2
1183 {
1184 let isnull1 = eq_lhs.clone().call_is_null();
1185 let isnull2 = eq_rhs.clone().call_is_null();
1186 let both_null = MirScalarExpr::CallVariadic {
1187 func: VariadicFunc::And,
1188 exprs: vec![isnull1, isnull2],
1189 };
1190
1191 if Self::extract_reduced_conjunction_terms(both_null, column_types)
1192 == Self::extract_reduced_conjunction_terms(or_arg1.clone(), column_types)
1193 {
1194 return Some(((**eq_lhs).clone(), (**eq_rhs).clone()));
1195 }
1196 }
1197 None
1198 }
1199
1200 /// Reduces the given expression and returns its AND-ed terms.
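///
/// For example, `a AND b AND c` yields `[a, b, c]` (after reduction), while a
/// non-AND expression `e` yields `[e]`.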
1201 fn extract_reduced_conjunction_terms(
1202 mut s: MirScalarExpr,
1203 column_types: &[SqlColumnType],
1204 ) -> Vec<MirScalarExpr> {
1205 s.reduce(column_types);
1206
1207 if let MirScalarExpr::CallVariadic {
1208 func: VariadicFunc::And,
1209 exprs,
1210 } = s
1211 {
1212 exprs
1213 } else {
1214 vec![s]
1215 }
1216 }
1217}