Skip to main content

mz_transform/movement/
projection_pushdown.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes column removal down through other operators.
11//!
12//! This action improves the quality of the query by
13//! reducing the width of data in the dataflow. It determines the unique
14//! columns an expression depends on, and pushes a projection onto only
15//! those columns down through child operators.
16//!
17//! A `MirRelationExpr::Project` node is actually three transformations in one.
18//! 1) Projection - removes columns.
19//! 2) Permutation - reorders columns.
20//! 3) Repetition - duplicates columns.
21//!
22//! This action handles these three transformations like so:
23//! 1) Projections are pushed as far down as possible.
24//! 2) Permutations are pushed as far down as is convenient.
25//! 3) Repetitions are not pushed down at all.
26//!
27//! Some comments have been inherited from the `Demand` transform.
28//!
29//! Note that this transform is one that can operate across views in a dataflow
30//! and thus currently exists outside of both the physical and logical
31//! optimizers.
32
33use std::collections::{BTreeMap, BTreeSet};
34
35use itertools::{Itertools, zip_eq};
36use mz_expr::{
37    Id, JoinImplementation, JoinInputMapper, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
38};
39use mz_ore::assert_none;
40use mz_ore::stack::{CheckedRecursion, RecursionGuard};
41use mz_repr::ReprRelationType;
42
43use crate::{TransformCtx, TransformError};
44
/// Pushes projections down through other operators.
#[derive(Debug)]
pub struct ProjectionPushdown {
    /// Guard exposed through the `CheckedRecursion` impl; both constructors
    /// build it with `RECURSION_LIMIT`.
    recursion_guard: RecursionGuard,
    /// When `false`, projections are not pushed through joins, although the
    /// join inputs are still visited.
    include_joins: bool,
}
51
52impl Default for ProjectionPushdown {
53    fn default() -> Self {
54        Self {
55            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
56            include_joins: true,
57        }
58    }
59}
60
61impl ProjectionPushdown {
62    /// Construct a `ProjectionPushdown` that does not push projections through joins (but does
63    /// descend into join inputs).
64    pub fn skip_joins() -> Self {
65        Self {
66            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
67            include_joins: false,
68        }
69    }
70}
71
72impl CheckedRecursion for ProjectionPushdown {
73    fn recursion_guard(&self) -> &RecursionGuard {
74        &self.recursion_guard
75    }
76}
77
78impl crate::Transform for ProjectionPushdown {
79    fn name(&self) -> &'static str {
80        "ProjectionPushdown"
81    }
82
83    // This method is only used during unit testing.
84    #[mz_ore::instrument(
85        target = "optimizer",
86        level = "debug",
87        fields(path.segment = "projection_pushdown")
88    )]
89    fn actually_perform_transform(
90        &self,
91        relation: &mut MirRelationExpr,
92        _: &mut TransformCtx,
93    ) -> Result<(), crate::TransformError> {
94        let result = self.action(
95            relation,
96            &(0..relation.arity()).collect(),
97            &mut BTreeMap::new(),
98        );
99        mz_repr::explain::trace_plan(&*relation);
100        result
101    }
102}
103
104impl ProjectionPushdown {
    /// Pushes the `desired_projection` down through `relation`.
    ///
    /// This action transforms `relation` to a `MirRelationExpr` equivalent to
    /// `relation.project(desired_projection)`.
    ///
    /// `desired_projection` is expected to consist of unique columns.
    ///
    /// `gets` accumulates, per `Id`, the set of columns demanded by the `Get`
    /// nodes visited so far. The `Let` and `LetRec` arms read these sets back
    /// to decide which columns to request from the corresponding bindings.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `checked_recur` or by a recursive
    /// call to `action`.
    ///
    /// # Panics
    ///
    /// Panics if `include_joins` is set and a `Join` with an implementation
    /// other than `JoinImplementation::Unimplemented` is encountered, or if a
    /// `LetRec` binds a non-recursive ID that already has an entry in `gets`.
    pub fn action(
        &self,
        relation: &mut MirRelationExpr,
        desired_projection: &Vec<usize>,
        gets: &mut BTreeMap<Id, BTreeSet<usize>>,
    ) -> Result<(), TransformError> {
        self.checked_recur(|_| {
            // First, try to push the desired projection down through `relation`.
            // In the process `relation` is transformed to a `MirRelationExpr`
            // equivalent to `relation.project(actual_projection)`.
            // There are three reasons why `actual_projection` may differ from
            // `desired_projection`:
            // 1) `relation` may need one or more columns that is not contained in
            //    `desired_projection`.
            // 2) `relation` may not be able to accommodate certain permutations.
            //    For example, `MirRelationExpr::Map` always appends all
            //    newly-created columns to the end.
            // 3) Nothing can be pushed through a leaf node. If `relation` is a leaf
            //    node, `actual_projection` will always be `(0..relation.arity())`.
            // Then, if `actual_projection` and `desired_projection` differ, we will
            // add a project around `relation`.
            let actual_projection = match relation {
                MirRelationExpr::Constant { .. } => (0..relation.arity()).collect(),
                MirRelationExpr::Get { id, .. } => {
                    // Record the demand so that an enclosing `Let`/`LetRec` can
                    // harvest it; the `Get` itself still yields every column.
                    gets.entry(*id)
                        .or_insert_with(BTreeSet::new)
                        .extend(desired_projection.iter().cloned());
                    (0..relation.arity()).collect()
                }
                MirRelationExpr::Let { id, value, body } => {
                    // Let harvests any requirements of get from its body,
                    // and pushes the sorted union of the requirements at its value.
                    let id = Id::Local(*id);
                    // Remember any pre-existing entry for this id so that it
                    // can be restored once the body's demand has been harvested.
                    let prior = gets.insert(id, BTreeSet::new());
                    self.action(body, desired_projection, gets)?;
                    let desired_value_projection = gets.remove(&id).unwrap();
                    if let Some(prior) = prior {
                        gets.insert(id, prior);
                    }
                    let desired_value_projection =
                        desired_value_projection.into_iter().collect::<Vec<_>>();
                    self.action(value, &desired_value_projection, gets)?;
                    // The value now produces only the demanded columns, so the
                    // `Project`s around its `Get`s in the body must be remapped.
                    let new_type = value.typ();
                    self.update_projection_around_get(
                        body,
                        &BTreeMap::from_iter(std::iter::once((
                            id,
                            (desired_value_projection, new_type),
                        ))),
                    );
                    desired_projection.clone()
                }
                MirRelationExpr::LetRec {
                    ids,
                    values,
                    limits: _,
                    body,
                } => {
                    // Determine the recursive IDs in this LetRec binding.
                    let rec_ids = MirRelationExpr::recursive_ids(ids, values);

                    // Seed the gets map with empty demand for each non-recursive ID.
                    for id in ids.iter().filter(|id| !rec_ids.contains(id)) {
                        let prior = gets.insert(Id::Local(*id), BTreeSet::new());
                        assert_none!(prior);
                    }

                    // Descend into the body with the supplied desired_projection.
                    self.action(body, desired_projection, gets)?;
                    // Descend into the values in reverse order.
                    for (id, value) in zip_eq(ids.iter().rev(), values.iter_mut().rev()) {
                        let desired_projection = if rec_ids.contains(id) {
                            // For recursive IDs: request all columns.
                            let columns = 0..value.arity();
                            columns.collect::<Vec<_>>()
                        } else {
                            // For non-recursive IDs: request the gets entry.
                            let columns = gets.get(&Id::Local(*id)).unwrap();
                            columns.iter().cloned().collect::<Vec<_>>()
                        };
                        self.action(value, &desired_projection, gets)?;
                    }

                    // Update projections around gets of non-recursive IDs.
                    let mut updates = BTreeMap::new();
                    for (id, value) in zip_eq(ids.iter(), values.iter_mut()) {
                        // Update the current value.
                        self.update_projection_around_get(value, &updates);
                        // If this is a non-recursive ID, add an entry to the
                        // updates map for subsequent values and the body.
                        if !rec_ids.contains(id) {
                            let new_type = value.typ();
                            let new_proj = {
                                let columns = gets.remove(&Id::Local(*id)).unwrap();
                                columns.iter().cloned().collect::<Vec<_>>()
                            };
                            updates.insert(Id::Local(*id), (new_proj, new_type));
                        }
                    }
                    // Update the body.
                    self.update_projection_around_get(body, &updates);

                    // Remove the entries for all ids (don't restrict only to
                    // non-recursive IDs here for better hygiene).
                    for id in ids.iter() {
                        gets.remove(&Id::Local(*id));
                    }

                    // Return the desired projection (leads to a no-op in the
                    // projection handling logic after this match statement).
                    desired_projection.clone()
                }
                MirRelationExpr::Join {
                    inputs,
                    equivalences,
                    implementation,
                } if self.include_joins => {
                    assert!(
                        matches!(implementation, JoinImplementation::Unimplemented),
                        "ProjectionPushdown can't deal with filled in join implementations. Turn off `include_joins` if you'd like to run it after `JoinImplementation`."
                    );

                    let input_mapper = JoinInputMapper::new(inputs);

                    let mut columns_to_pushdown =
                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
                    // Each equivalence class imposes internal demand for columns.
                    for equivalence in equivalences.iter() {
                        for expr in equivalence.iter() {
                            expr.support_into(&mut columns_to_pushdown);
                        }
                    }

                    // Populate child demands from external and internal demands.
                    let new_columns =
                        input_mapper.split_column_set_by_input(columns_to_pushdown.iter());

                    // Recursively indicate the requirements.
                    for (input, inp_columns) in inputs.iter_mut().zip_eq(new_columns) {
                        let inp_columns = inp_columns.into_iter().collect::<Vec<_>>();
                        self.action(input, &inp_columns, gets)?;
                    }

                    // The equivalences still refer to the old column numbers;
                    // rewrite them in terms of the narrowed join output.
                    reverse_permute(
                        equivalences.iter_mut().flat_map(|e| e.iter_mut()),
                        columns_to_pushdown.iter(),
                    );

                    columns_to_pushdown.into_iter().collect()
                }
                // Skip joins if `self.include_joins` is turned off.
                MirRelationExpr::Join { inputs, equivalences: _, implementation: _ } => {
                    let input_mapper = JoinInputMapper::new(inputs);

                    // Include all columns.
                    let columns_to_pushdown: Vec<_> = (0..input_mapper.total_columns()).collect();
                    let child_columns =
                        input_mapper.split_column_set_by_input(columns_to_pushdown.iter());

                    // Recursively indicate the requirements.
                    for (input, inp_columns) in inputs.iter_mut().zip_eq(child_columns) {
                        let inp_columns = inp_columns.into_iter().collect::<Vec<_>>();
                        self.action(input, &inp_columns, gets)?;
                    }

                    columns_to_pushdown.into_iter().collect()
                }
                MirRelationExpr::FlatMap { input, func, exprs } => {
                    let inner_arity = input.arity();
                    // A FlatMap which returns zero rows acts like a filter
                    // so we always need to execute it
                    let mut columns_to_pushdown =
                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
                    for expr in exprs.iter() {
                        expr.support_into(&mut columns_to_pushdown);
                    }
                    // Only columns of the input can be requested from the
                    // input; columns at or past `inner_arity` are dropped here
                    // and re-added below.
                    columns_to_pushdown.retain(|c| *c < inner_arity);

                    reverse_permute(exprs.iter_mut(), columns_to_pushdown.iter());
                    let columns_to_pushdown = columns_to_pushdown.into_iter().collect::<Vec<_>>();
                    self.action(input, &columns_to_pushdown, gets)?;
                    // The actual projection always has the newly-created columns at
                    // the end.
                    let mut actual_projection = columns_to_pushdown;
                    for c in 0..func.output_arity() {
                        actual_projection.push(inner_arity + c);
                    }
                    actual_projection
                }
                MirRelationExpr::Filter { input, predicates } => {
                    // The predicates' supports are demanded in addition to
                    // the columns requested from above.
                    let mut columns_to_pushdown =
                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
                    for predicate in predicates.iter() {
                        predicate.support_into(&mut columns_to_pushdown);
                    }
                    reverse_permute(predicates.iter_mut(), columns_to_pushdown.iter());
                    let columns_to_pushdown = columns_to_pushdown.into_iter().collect::<Vec<_>>();
                    self.action(input, &columns_to_pushdown, gets)?;
                    columns_to_pushdown
                }
                MirRelationExpr::Project { input, outputs } => {
                    // Combine `outputs` with `desired_projection`.
                    *outputs = desired_projection.iter().map(|c| outputs[*c]).collect();

                    let unique_outputs = outputs.iter().map(|i| *i).collect::<BTreeSet<_>>();
                    if outputs.len() == unique_outputs.len() {
                        // Push down the project as is.
                        self.action(input, outputs, gets)?;
                        // The input now yields exactly `outputs`, so this
                        // `Project` node is replaced by its input.
                        *relation = input.take_dangerous();
                    } else {
                        // Push down only the unique elems in `outputs`.
                        let columns_to_pushdown = unique_outputs.into_iter().collect::<Vec<_>>();
                        reverse_permute_columns(outputs.iter_mut(), columns_to_pushdown.iter());
                        self.action(input, &columns_to_pushdown, gets)?;
                    }

                    desired_projection.clone()
                }
                MirRelationExpr::Map { input, scalars } => {
                    let arity = input.arity();
                    // contains columns whose supports have yet to be explored
                    let mut actual_projection =
                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
                    // Visit the scalars back to front: the support of a
                    // retained scalar is added to the set before the earlier
                    // scalars are tested against it.
                    for (i, scalar) in scalars.iter().enumerate().rev() {
                        if actual_projection.contains(&(i + arity)) {
                            scalar.support_into(&mut actual_projection);
                        }
                    }
                    // Keep only the scalars whose output column is demanded.
                    *scalars = (0..scalars.len())
                        .filter_map(|i| {
                            if actual_projection.contains(&(i + arity)) {
                                Some(scalars[i].clone())
                            } else {
                                None
                            }
                        })
                        .collect::<Vec<_>>();
                    reverse_permute(scalars.iter_mut(), actual_projection.iter());
                    // Only the columns below `arity` are requested from the input.
                    self.action(
                        input,
                        &actual_projection
                            .iter()
                            .filter(|c| **c < arity)
                            .map(|c| *c)
                            .collect(),
                        gets,
                    )?;
                    actual_projection.into_iter().collect()
                }
                MirRelationExpr::Reduce {
                    input,
                    group_key,
                    aggregates,
                    monotonic: _,
                    expected_group_size: _,
                } => {
                    let mut columns_to_pushdown = BTreeSet::new();
                    // Group keys determine aggregation granularity and are
                    // each crucial in determining aggregates and even the
                    // multiplicities of other keys.
                    for k in group_key.iter() {
                        k.support_into(&mut columns_to_pushdown)
                    }

                    // Walk the aggregates from the back so that removing one
                    // does not shift the indices that are still to be visited.
                    for index in (0..aggregates.len()).rev() {
                        if !desired_projection.contains(&(group_key.len() + index)) {
                            aggregates.remove(index);
                        } else {
                            // No obvious requirements on aggregate columns.
                            // A "non-empty" requirement, I guess?
                            aggregates[index]
                                .expr
                                .support_into(&mut columns_to_pushdown)
                        }
                    }

                    reverse_permute(
                        itertools::chain!(
                            group_key.iter_mut(),
                            aggregates.iter_mut().map(|a| &mut a.expr)
                        ),
                        columns_to_pushdown.iter(),
                    );

                    self.action(
                        input,
                        &columns_to_pushdown.into_iter().collect::<Vec<_>>(),
                        gets,
                    )?;
                    // Every group key column is reported as produced, next to
                    // the demanded columns.
                    let mut actual_projection =
                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
                    actual_projection.extend(0..group_key.len());
                    actual_projection.into_iter().collect()
                }
                MirRelationExpr::TopK {
                    input,
                    group_key,
                    order_key,
                    limit,
                    ..
                } => {
                    // Group and order keys and limit support must be retained, as
                    // they define which rows are retained.
                    let mut columns_to_pushdown =
                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
                    columns_to_pushdown.extend(group_key.iter().cloned());
                    columns_to_pushdown.extend(order_key.iter().map(|o| o.column));
                    if let Some(limit) = limit.as_ref() {
                        // Strictly speaking not needed because the
                        // `limit` support should be a subset of the
                        // `group_key` support, but we don't want to
                        // take this for granted here.
                        limit.support_into(&mut columns_to_pushdown);
                    }
                    // If the `TopK` does not have any new column demand, just push
                    // down the desired projection. Otherwise, push down the sorted
                    // column demand.
                    let columns_to_pushdown =
                        if columns_to_pushdown.len() == desired_projection.len() {
                            desired_projection.clone()
                        } else {
                            columns_to_pushdown.into_iter().collect::<Vec<_>>()
                        };
                    reverse_permute_columns(
                        itertools::chain!(
                            group_key.iter_mut(),
                            order_key.iter_mut().map(|o| &mut o.column),
                        ),
                        columns_to_pushdown.iter(),
                    );
                    reverse_permute(limit.iter_mut(), columns_to_pushdown.iter());
                    self.action(input, &columns_to_pushdown, gets)?;
                    columns_to_pushdown
                }
                MirRelationExpr::Negate { input } => {
                    self.action(input, desired_projection, gets)?;
                    desired_projection.clone()
                }
                MirRelationExpr::Union { base, inputs } => {
                    // Every branch receives the same projection.
                    self.action(base, desired_projection, gets)?;
                    for input in inputs {
                        self.action(input, desired_projection, gets)?;
                    }
                    desired_projection.clone()
                }
                MirRelationExpr::Threshold { input } => {
                    // Threshold requires all columns, as collapsing any distinct values
                    // has the potential to change how it thresholds counts. This could
                    // be improved with reasoning about distinctness or non-negativity.
                    let arity = input.arity();
                    self.action(input, &(0..arity).collect(), gets)?;
                    (0..arity).collect()
                }
                MirRelationExpr::ArrangeBy { input, keys: _ } => {
                    // Do not push the project past the ArrangeBy.
                    // TODO: how do we handle key sets containing column references
                    // that are not demanded upstream?
                    let arity = input.arity();
                    self.action(input, &(0..arity).collect(), gets)?;
                    (0..arity).collect()
                }
            };
            // If what `relation` now produces differs from what was asked
            // for, wrap it in a `Project` that restores the requested columns
            // and order. The requested columns are first translated into
            // positions within `actual_projection`.
            let add_project = desired_projection != &actual_projection;
            if add_project {
                let mut projection_to_add = desired_projection.to_owned();
                reverse_permute_columns(projection_to_add.iter_mut(), actual_projection.iter());
                *relation = relation.take_dangerous().project(projection_to_add);
            }
            Ok(())
        })
    }
482
483    /// When we push the `desired_value_projection` at `value`,
484    /// the columns returned by `Get(get_id)` will change, so we need
485    /// to permute `Project`s around `Get(get_id)`.
486    pub fn update_projection_around_get(
487        &self,
488        relation: &mut MirRelationExpr,
489        applied_projections: &BTreeMap<Id, (Vec<usize>, ReprRelationType)>,
490    ) {
491        relation.visit_pre_mut(|e| {
492            if let MirRelationExpr::Project { input, outputs } = e {
493                if let MirRelationExpr::Get {
494                    id: inner_id,
495                    typ,
496                    access_strategy: _,
497                } = &mut **input
498                {
499                    if let Some((new_projection, new_type)) = applied_projections.get(inner_id) {
500                        typ.clone_from(new_type);
501                        reverse_permute_columns(outputs.iter_mut(), new_projection.iter());
502                        if outputs.len() == new_projection.len()
503                            && outputs.iter().enumerate().all(|(i, o)| i == *o)
504                        {
505                            *e = input.take_dangerous();
506                        }
507                    }
508                }
509            }
510            // If there is no `Project` around a Get, all columns of
511            // `Get(get_id)` are required. Thus, the columns returned by
512            // `Get(get_id)` will not have changed, so no action
513            // is necessary.
514        });
515    }
516}
517
518/// Applies the reverse of [MirScalarExpr.permute] on each expression.
519///
520/// `permutation` can be thought of as a mapping of column references from
521/// `stateA` to `stateB`. [MirScalarExpr.permute] assumes that the column
522/// references of the expression are in `stateA` and need to be remapped to
523/// their `stateB` counterparts. This methods assumes that the column
524/// references are in `stateB` and need to be remapped to `stateA`.
525///
526/// The `outputs` field of [MirRelationExpr::Project] is a mapping from "after"
527/// to "before". Thus, when lifting projections, you would permute on `outputs`,
528/// but you need to reverse permute when pushing projections down.
529fn reverse_permute<'a, I, J>(exprs: I, permutation: J)
530where
531    I: Iterator<Item = &'a mut MirScalarExpr>,
532    J: Iterator<Item = &'a usize>,
533{
534    let reverse_col_map = permutation
535        .enumerate()
536        .map(|(idx, c)| (*c, idx))
537        .collect::<BTreeMap<_, _>>();
538    for expr in exprs {
539        expr.permute_map(&reverse_col_map);
540    }
541}
542
/// Column-number counterpart of [reverse_permute].
///
/// Each entry of `columns` is overwritten with the position at which that
/// column occurs in `permutation`. If a column occurs more than once in
/// `permutation`, its last position is used.
///
/// # Panics
///
/// Panics if an entry of `columns` does not occur in `permutation`.
fn reverse_permute_columns<'a, I, J>(columns: I, permutation: J)
where
    I: Iterator<Item = &'a mut usize>,
    J: Iterator<Item = &'a usize>,
{
    let mut position_of = BTreeMap::new();
    for (position, column) in permutation.enumerate() {
        position_of.insert(*column, position);
    }
    columns.for_each(|column| *column = position_of[&*column]);
}
556}