// mz_transform/movement/projection_pushdown.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Pushes column removal down through other operators.
//!
//! This action improves the quality of the query by
//! reducing the width of data in the dataflow. It determines the unique
//! columns an expression depends on, and pushes a projection onto only
//! those columns down through child operators.
//!
//! A `MirRelationExpr::Project` node is actually three transformations in one.
//! 1) Projection - removes columns.
//! 2) Permutation - reorders columns.
//! 3) Repetition - duplicates columns.
//!
//! This action handles these three transformations like so:
//! 1) Projections are pushed as far down as possible.
//! 2) Permutations are pushed as far down as is convenient.
//! 3) Repetitions are not pushed down at all.
//!
//! Some comments have been inherited from the `Demand` transform.
//!
//! Note that this transform is one that can operate across views in a dataflow
//! and thus currently exists outside of both the physical and logical
//! optimizers.

33use std::collections::{BTreeMap, BTreeSet};
34
35use itertools::{Itertools, zip_eq};
36use mz_expr::{
37    Columns, Id, JoinImplementation, JoinInputMapper, MirRelationExpr, MirScalarExpr,
38    RECURSION_LIMIT,
39};
40use mz_ore::assert_none;
41use mz_ore::stack::{CheckedRecursion, RecursionGuard};
42use mz_repr::ReprRelationType;
43
44use crate::{TransformCtx, TransformError};

/// Pushes projections down through other operators.
#[derive(Debug)]
pub struct ProjectionPushdown {
    /// Guard used by `checked_recur` to bound the recursion depth of `action`.
    recursion_guard: RecursionGuard,
    /// Whether projections are pushed through `Join` nodes. When `false`, every
    /// column of each join input is demanded and any projection stays above the join.
    include_joins: bool,
}

53impl Default for ProjectionPushdown {
54    fn default() -> Self {
55        Self {
56            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
57            include_joins: true,
58        }
59    }
60}
61
62impl ProjectionPushdown {
63    /// Construct a `ProjectionPushdown` that does not push projections through joins (but does
64    /// descend into join inputs).
65    pub fn skip_joins() -> Self {
66        Self {
67            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
68            include_joins: false,
69        }
70    }
71}
72
73impl CheckedRecursion for ProjectionPushdown {
74    fn recursion_guard(&self) -> &RecursionGuard {
75        &self.recursion_guard
76    }
77}
78
79impl crate::Transform for ProjectionPushdown {
80    fn name(&self) -> &'static str {
81        "ProjectionPushdown"
82    }
83
84    // This method is only used during unit testing.
85    #[mz_ore::instrument(
86        target = "optimizer",
87        level = "debug",
88        fields(path.segment = "projection_pushdown")
89    )]
90    fn actually_perform_transform(
91        &self,
92        relation: &mut MirRelationExpr,
93        _: &mut TransformCtx,
94    ) -> Result<(), crate::TransformError> {
95        let result = self.action(
96            relation,
97            &(0..relation.arity()).collect(),
98            &mut BTreeMap::new(),
99        );
100        mz_repr::explain::trace_plan(&*relation);
101        result
102    }
103}
104
105impl ProjectionPushdown {
106    /// Pushes the `desired_projection` down through `relation`.
107    ///
108    /// This action transforms `relation` to a `MirRelationExpr` equivalent to
109    /// `relation.project(desired_projection)`.
110    ///
111    /// `desired_projection` is expected to consist of unique columns.
112    pub fn action(
113        &self,
114        relation: &mut MirRelationExpr,
115        desired_projection: &Vec<usize>,
116        gets: &mut BTreeMap<Id, BTreeSet<usize>>,
117    ) -> Result<(), TransformError> {
118        self.checked_recur(|_| {
119            // First, try to push the desired projection down through `relation`.
120            // In the process `relation` is transformed to a `MirRelationExpr`
121            // equivalent to `relation.project(actual_projection)`.
122            // There are three reasons why `actual_projection` may differ from
123            // `desired_projection`:
124            // 1) `relation` may need one or more columns that is not contained in
125            //    `desired_projection`.
126            // 2) `relation` may not be able to accommodate certain permutations.
127            //    For example, `MirRelationExpr::Map` always appends all
128            //    newly-created columns to the end.
129            // 3) Nothing can be pushed through a leaf node. If `relation` is a leaf
130            //    node, `actual_projection` will always be `(0..relation.arity())`.
131            // Then, if `actual_projection` and `desired_projection` differ, we will
132            // add a project around `relation`.
133            let actual_projection = match relation {
134                MirRelationExpr::Constant { .. } => (0..relation.arity()).collect(),
135                MirRelationExpr::Get { id, .. } => {
136                    gets.entry(*id)
137                        .or_insert_with(BTreeSet::new)
138                        .extend(desired_projection.iter().cloned());
139                    (0..relation.arity()).collect()
140                }
141                MirRelationExpr::Let { id, value, body } => {
142                    // Let harvests any requirements of get from its body,
143                    // and pushes the sorted union of the requirements at its value.
144                    let id = Id::Local(*id);
145                    let prior = gets.insert(id, BTreeSet::new());
146                    self.action(body, desired_projection, gets)?;
147                    let desired_value_projection = gets.remove(&id).unwrap();
148                    if let Some(prior) = prior {
149                        gets.insert(id, prior);
150                    }
151                    let desired_value_projection =
152                        desired_value_projection.into_iter().collect::<Vec<_>>();
153                    self.action(value, &desired_value_projection, gets)?;
154                    let new_type = value.typ();
155                    self.update_projection_around_get(
156                        body,
157                        &BTreeMap::from_iter(std::iter::once((
158                            id,
159                            (desired_value_projection, new_type),
160                        ))),
161                    );
162                    desired_projection.clone()
163                }
164                MirRelationExpr::LetRec {
165                    ids,
166                    values,
167                    limits: _,
168                    body,
169                } => {
170                    // Determine the recursive IDs in this LetRec binding.
171                    let rec_ids = MirRelationExpr::recursive_ids(ids, values);
172
173                    // Seed the gets map with empty demand for each non-recursive ID.
174                    for id in ids.iter().filter(|id| !rec_ids.contains(id)) {
175                        let prior = gets.insert(Id::Local(*id), BTreeSet::new());
176                        assert_none!(prior);
177                    }
178
179                    // Descend into the body with the supplied desired_projection.
180                    self.action(body, desired_projection, gets)?;
181                    // Descend into the values in reverse order.
182                    for (id, value) in zip_eq(ids.iter().rev(), values.iter_mut().rev()) {
183                        let desired_projection = if rec_ids.contains(id) {
184                            // For recursive IDs: request all columns.
185                            let columns = 0..value.arity();
186                            columns.collect::<Vec<_>>()
187                        } else {
188                            // For non-recursive IDs: request the gets entry.
189                            let columns = gets.get(&Id::Local(*id)).unwrap();
190                            columns.iter().cloned().collect::<Vec<_>>()
191                        };
192                        self.action(value, &desired_projection, gets)?;
193                    }
194
195                    // Update projections around gets of non-recursive IDs.
196                    let mut updates = BTreeMap::new();
197                    for (id, value) in zip_eq(ids.iter(), values.iter_mut()) {
198                        // Update the current value.
199                        self.update_projection_around_get(value, &updates);
200                        // If this is a non-recursive ID, add an entry to the
201                        // updates map for subsequent values and the body.
202                        if !rec_ids.contains(id) {
203                            let new_type = value.typ();
204                            let new_proj = {
205                                let columns = gets.remove(&Id::Local(*id)).unwrap();
206                                columns.iter().cloned().collect::<Vec<_>>()
207                            };
208                            updates.insert(Id::Local(*id), (new_proj, new_type));
209                        }
210                    }
211                    // Update the body.
212                    self.update_projection_around_get(body, &updates);
213
214                    // Remove the entries for all ids (don't restrict only to
215                    // non-recursive IDs here for better hygene).
216                    for id in ids.iter() {
217                        gets.remove(&Id::Local(*id));
218                    }
219
220                    // Return the desired projection (leads to a no-op in the
221                    // projection handling logic after this match statement).
222                    desired_projection.clone()
223                }
224                MirRelationExpr::Join {
225                    inputs,
226                    equivalences,
227                    implementation,
228                } if self.include_joins => {
229                    assert!(
230                        matches!(implementation, JoinImplementation::Unimplemented),
231                        "ProjectionPushdown can't deal with filled in join implementations. Turn off `include_joins` if you'd like to run it after `JoinImplementation`."
232                    );
233
234                    let input_mapper = JoinInputMapper::new(inputs);
235
236                    let mut columns_to_pushdown =
237                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
238                    // Each equivalence class imposes internal demand for columns.
239                    for equivalence in equivalences.iter() {
240                        for expr in equivalence.iter() {
241                            expr.support_into(&mut columns_to_pushdown);
242                        }
243                    }
244
245                    // Populate child demands from external and internal demands.
246                    let new_columns =
247                        input_mapper.split_column_set_by_input(columns_to_pushdown.iter());
248
249                    // Recursively indicate the requirements.
250                    for (input, inp_columns) in inputs.iter_mut().zip_eq(new_columns) {
251                        let inp_columns = inp_columns.into_iter().collect::<Vec<_>>();
252                        self.action(input, &inp_columns, gets)?;
253                    }
254
255                    reverse_permute(
256                        equivalences.iter_mut().flat_map(|e| e.iter_mut()),
257                        columns_to_pushdown.iter(),
258                    );
259
260                    columns_to_pushdown.into_iter().collect()
261                }
262                // Skip joins if `self.include_joins` is turned off.
263                MirRelationExpr::Join { inputs, equivalences: _, implementation: _ } => {
264                    let input_mapper = JoinInputMapper::new(inputs);
265
266                    // Include all columns.
267                    let columns_to_pushdown: Vec<_> = (0..input_mapper.total_columns()).collect();
268                    let child_columns =
269                        input_mapper.split_column_set_by_input(columns_to_pushdown.iter());
270
271                    // Recursively indicate the requirements.
272                    for (input, inp_columns) in inputs.iter_mut().zip_eq(child_columns) {
273                        let inp_columns = inp_columns.into_iter().collect::<Vec<_>>();
274                        self.action(input, &inp_columns, gets)?;
275                    }
276
277                    columns_to_pushdown.into_iter().collect()
278                }
279                MirRelationExpr::FlatMap { input, func, exprs } => {
280                    let inner_arity = input.arity();
281                    // A FlatMap which returns zero rows acts like a filter
282                    // so we always need to execute it
283                    let mut columns_to_pushdown =
284                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
285                    for expr in exprs.iter() {
286                        expr.support_into(&mut columns_to_pushdown);
287                    }
288                    columns_to_pushdown.retain(|c| *c < inner_arity);
289
290                    reverse_permute(exprs.iter_mut(), columns_to_pushdown.iter());
291                    let columns_to_pushdown = columns_to_pushdown.into_iter().collect::<Vec<_>>();
292                    self.action(input, &columns_to_pushdown, gets)?;
293                    // The actual projection always has the newly-created columns at
294                    // the end.
295                    let mut actual_projection = columns_to_pushdown;
296                    for c in 0..func.output_arity() {
297                        actual_projection.push(inner_arity + c);
298                    }
299                    actual_projection
300                }
301                MirRelationExpr::Filter { input, predicates } => {
302                    let mut columns_to_pushdown =
303                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
304                    for predicate in predicates.iter() {
305                        predicate.support_into(&mut columns_to_pushdown);
306                    }
307                    reverse_permute(predicates.iter_mut(), columns_to_pushdown.iter());
308                    let columns_to_pushdown = columns_to_pushdown.into_iter().collect::<Vec<_>>();
309                    self.action(input, &columns_to_pushdown, gets)?;
310                    columns_to_pushdown
311                }
312                MirRelationExpr::Project { input, outputs } => {
313                    // Combine `outputs` with `desired_projection`.
314                    *outputs = desired_projection.iter().map(|c| outputs[*c]).collect();
315
316                    let unique_outputs = outputs.iter().map(|i| *i).collect::<BTreeSet<_>>();
317                    if outputs.len() == unique_outputs.len() {
318                        // Push down the project as is.
319                        self.action(input, outputs, gets)?;
320                        *relation = input.take_dangerous();
321                    } else {
322                        // Push down only the unique elems in `outputs`.
323                        let columns_to_pushdown = unique_outputs.into_iter().collect::<Vec<_>>();
324                        reverse_permute_columns(outputs.iter_mut(), columns_to_pushdown.iter());
325                        self.action(input, &columns_to_pushdown, gets)?;
326                    }
327
328                    desired_projection.clone()
329                }
330                MirRelationExpr::Map { input, scalars } => {
331                    let arity = input.arity();
332                    // contains columns whose supports have yet to be explored
333                    let mut actual_projection =
334                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
335                    for (i, scalar) in scalars.iter().enumerate().rev() {
336                        if actual_projection.contains(&(i + arity)) {
337                            scalar.support_into(&mut actual_projection);
338                        }
339                    }
340                    *scalars = (0..scalars.len())
341                        .filter_map(|i| {
342                            if actual_projection.contains(&(i + arity)) {
343                                Some(scalars[i].clone())
344                            } else {
345                                None
346                            }
347                        })
348                        .collect::<Vec<_>>();
349                    reverse_permute(scalars.iter_mut(), actual_projection.iter());
350                    self.action(
351                        input,
352                        &actual_projection
353                            .iter()
354                            .filter(|c| **c < arity)
355                            .map(|c| *c)
356                            .collect(),
357                        gets,
358                    )?;
359                    actual_projection.into_iter().collect()
360                }
361                MirRelationExpr::Reduce {
362                    input,
363                    group_key,
364                    aggregates,
365                    monotonic: _,
366                    expected_group_size: _,
367                } => {
368                    let mut columns_to_pushdown = BTreeSet::new();
369                    // Group keys determine aggregation granularity and are
370                    // each crucial in determining aggregates and even the
371                    // multiplicities of other keys.
372                    for k in group_key.iter() {
373                        k.support_into(&mut columns_to_pushdown)
374                    }
375
376                    for index in (0..aggregates.len()).rev() {
377                        if !desired_projection.contains(&(group_key.len() + index)) {
378                            aggregates.remove(index);
379                        } else {
380                            // No obvious requirements on aggregate columns.
381                            // A "non-empty" requirement, I guess?
382                            aggregates[index]
383                                .expr
384                                .support_into(&mut columns_to_pushdown)
385                        }
386                    }
387
388                    reverse_permute(
389                        itertools::chain!(
390                            group_key.iter_mut(),
391                            aggregates.iter_mut().map(|a| &mut a.expr)
392                        ),
393                        columns_to_pushdown.iter(),
394                    );
395
396                    self.action(
397                        input,
398                        &columns_to_pushdown.into_iter().collect::<Vec<_>>(),
399                        gets,
400                    )?;
401                    let mut actual_projection =
402                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
403                    actual_projection.extend(0..group_key.len());
404                    actual_projection.into_iter().collect()
405                }
406                MirRelationExpr::TopK {
407                    input,
408                    group_key,
409                    order_key,
410                    limit,
411                    ..
412                } => {
413                    // Group and order keys and limit support must be retained, as
414                    // they define which rows are retained.
415                    let mut columns_to_pushdown =
416                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
417                    columns_to_pushdown.extend(group_key.iter().cloned());
418                    columns_to_pushdown.extend(order_key.iter().map(|o| o.column));
419                    if let Some(limit) = limit.as_ref() {
420                        // Strictly speaking not needed because the
421                        // `limit` support should be a subset of the
422                        // `group_key` support, but we don't want to
423                        // take this for granted here.
424                        limit.support_into(&mut columns_to_pushdown);
425                    }
426                    // If the `TopK` does not have any new column demand, just push
427                    // down the desired projection. Otherwise, push down the sorted
428                    // column demand.
429                    let columns_to_pushdown =
430                        if columns_to_pushdown.len() == desired_projection.len() {
431                            desired_projection.clone()
432                        } else {
433                            columns_to_pushdown.into_iter().collect::<Vec<_>>()
434                        };
435                    reverse_permute_columns(
436                        itertools::chain!(
437                            group_key.iter_mut(),
438                            order_key.iter_mut().map(|o| &mut o.column),
439                        ),
440                        columns_to_pushdown.iter(),
441                    );
442                    reverse_permute(limit.iter_mut(), columns_to_pushdown.iter());
443                    self.action(input, &columns_to_pushdown, gets)?;
444                    columns_to_pushdown
445                }
446                MirRelationExpr::Negate { input } => {
447                    self.action(input, desired_projection, gets)?;
448                    desired_projection.clone()
449                }
450                MirRelationExpr::Union { base, inputs } => {
451                    self.action(base, desired_projection, gets)?;
452                    for input in inputs {
453                        self.action(input, desired_projection, gets)?;
454                    }
455                    desired_projection.clone()
456                }
457                MirRelationExpr::Threshold { input } => {
458                    // Threshold requires all columns, as collapsing any distinct values
459                    // has the potential to change how it thresholds counts. This could
460                    // be improved with reasoning about distinctness or non-negativity.
461                    let arity = input.arity();
462                    self.action(input, &(0..arity).collect(), gets)?;
463                    (0..arity).collect()
464                }
465                MirRelationExpr::ArrangeBy { input, keys: _ } => {
466                    // Do not push the project past the ArrangeBy.
467                    // TODO: how do we handle key sets containing column references
468                    // that are not demanded upstream?
469                    let arity = input.arity();
470                    self.action(input, &(0..arity).collect(), gets)?;
471                    (0..arity).collect()
472                }
473            };
474            let add_project = desired_projection != &actual_projection;
475            if add_project {
476                let mut projection_to_add = desired_projection.to_owned();
477                reverse_permute_columns(projection_to_add.iter_mut(), actual_projection.iter());
478                *relation = relation.take_dangerous().project(projection_to_add);
479            }
480            Ok(())
481        })
482    }
483
484    /// When we push the `desired_value_projection` at `value`,
485    /// the columns returned by `Get(get_id)` will change, so we need
486    /// to permute `Project`s around `Get(get_id)`.
487    pub fn update_projection_around_get(
488        &self,
489        relation: &mut MirRelationExpr,
490        applied_projections: &BTreeMap<Id, (Vec<usize>, ReprRelationType)>,
491    ) {
492        relation.visit_pre_mut(|e| {
493            if let MirRelationExpr::Project { input, outputs } = e {
494                if let MirRelationExpr::Get {
495                    id: inner_id,
496                    typ,
497                    access_strategy: _,
498                } = &mut **input
499                {
500                    if let Some((new_projection, new_type)) = applied_projections.get(inner_id) {
501                        typ.clone_from(new_type);
502                        reverse_permute_columns(outputs.iter_mut(), new_projection.iter());
503                        if outputs.len() == new_projection.len()
504                            && outputs.iter().enumerate().all(|(i, o)| i == *o)
505                        {
506                            *e = input.take_dangerous();
507                        }
508                    }
509                }
510            }
511            // If there is no `Project` around a Get, all columns of
512            // `Get(get_id)` are required. Thus, the columns returned by
513            // `Get(get_id)` will not have changed, so no action
514            // is necessary.
515        });
516    }
517}
518
519/// Applies the reverse of [MirScalarExpr.permute] on each expression.
520///
521/// `permutation` can be thought of as a mapping of column references from
522/// `stateA` to `stateB`. [MirScalarExpr.permute] assumes that the column
523/// references of the expression are in `stateA` and need to be remapped to
524/// their `stateB` counterparts. This methods assumes that the column
525/// references are in `stateB` and need to be remapped to `stateA`.
526///
527/// The `outputs` field of [MirRelationExpr::Project] is a mapping from "after"
528/// to "before". Thus, when lifting projections, you would permute on `outputs`,
529/// but you need to reverse permute when pushing projections down.
530fn reverse_permute<'a, I, J>(exprs: I, permutation: J)
531where
532    I: Iterator<Item = &'a mut MirScalarExpr>,
533    J: Iterator<Item = &'a usize>,
534{
535    let reverse_col_map = permutation
536        .enumerate()
537        .map(|(idx, c)| (*c, idx))
538        .collect::<BTreeMap<_, _>>();
539    for expr in exprs {
540        expr.permute_map(&reverse_col_map);
541    }
542}
543
/// Same as [`reverse_permute`], but rewrites bare column numbers instead of the
/// column references inside expressions: every `c` equal to `permutation[idx]`
/// becomes `idx`.
///
/// If a column occurs more than once in `permutation`, its last position wins.
///
/// # Panics
///
/// Panics if some entry of `columns` does not occur in `permutation`.
fn reverse_permute_columns<'a, I, J>(columns: I, permutation: J)
where
    I: Iterator<Item = &'a mut usize>,
    J: Iterator<Item = &'a usize>,
{
    let mut position_of = BTreeMap::new();
    for (position, &column) in permutation.enumerate() {
        position_of.insert(column, position);
    }
    columns.for_each(|column| *column = position_of[&*column]);
}
557}