mz_transform/movement/
projection_pushdown.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Pushes column removal down through other operators.
//!
//! This action improves the quality of the query by
//! reducing the width of data in the dataflow. It determines the unique
//! columns an expression depends on and pushes a projection onto only
//! those columns down through child operators.
//!
//! A `MirRelationExpr::Project` node is actually three transformations in one:
//! 1) Projection - removes columns.
//! 2) Permutation - reorders columns.
//! 3) Repetition - duplicates columns.
//!
//! This action handles these three transformations like so:
//! 1) Projections are pushed as far down as possible.
//! 2) Permutations are pushed as far down as is convenient.
//! 3) Repetitions are not pushed down at all.
//!
//! Some comments have been inherited from the `Demand` transform.
//!
//! Note that this transform is one that can operate across views in a dataflow
//! and thus currently exists outside of both the physical and logical
//! optimizers.
32
33use std::collections::{BTreeMap, BTreeSet};
34
35use itertools::zip_eq;
36use mz_expr::{
37    Id, JoinImplementation, JoinInputMapper, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
38};
39use mz_ore::assert_none;
40use mz_ore::stack::{CheckedRecursion, RecursionGuard};
41
42use crate::{TransformCtx, TransformError};
43
44/// Pushes projections down through other operators.
45#[derive(Debug)]
46pub struct ProjectionPushdown {
47    recursion_guard: RecursionGuard,
48    include_joins: bool,
49}
50
51impl Default for ProjectionPushdown {
52    fn default() -> Self {
53        Self {
54            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
55            include_joins: true,
56        }
57    }
58}
59
60impl ProjectionPushdown {
61    /// Construct a `ProjectionPushdown` that does not push projections through joins (but does
62    /// descend into join inputs).
63    pub fn skip_joins() -> Self {
64        Self {
65            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
66            include_joins: false,
67        }
68    }
69}
70
71impl CheckedRecursion for ProjectionPushdown {
72    fn recursion_guard(&self) -> &RecursionGuard {
73        &self.recursion_guard
74    }
75}
76
77impl crate::Transform for ProjectionPushdown {
78    fn name(&self) -> &'static str {
79        "ProjectionPushdown"
80    }
81
82    // This method is only used during unit testing.
83    #[mz_ore::instrument(
84        target = "optimizer",
85        level = "debug",
86        fields(path.segment = "projection_pushdown")
87    )]
88    fn actually_perform_transform(
89        &self,
90        relation: &mut MirRelationExpr,
91        _: &mut TransformCtx,
92    ) -> Result<(), crate::TransformError> {
93        let result = self.action(
94            relation,
95            &(0..relation.arity()).collect(),
96            &mut BTreeMap::new(),
97        );
98        mz_repr::explain::trace_plan(&*relation);
99        result
100    }
101}
102
103impl ProjectionPushdown {
104    /// Pushes the `desired_projection` down through `relation`.
105    ///
106    /// This action transforms `relation` to a `MirRelationExpr` equivalent to
107    /// `relation.project(desired_projection)`.
108    ///
109    /// `desired_projection` is expected to consist of unique columns.
110    pub fn action(
111        &self,
112        relation: &mut MirRelationExpr,
113        desired_projection: &Vec<usize>,
114        gets: &mut BTreeMap<Id, BTreeSet<usize>>,
115    ) -> Result<(), TransformError> {
116        self.checked_recur(|_| {
117            // First, try to push the desired projection down through `relation`.
118            // In the process `relation` is transformed to a `MirRelationExpr`
119            // equivalent to `relation.project(actual_projection)`.
120            // There are three reasons why `actual_projection` may differ from
121            // `desired_projection`:
122            // 1) `relation` may need one or more columns that is not contained in
123            //    `desired_projection`.
124            // 2) `relation` may not be able to accommodate certain permutations.
125            //    For example, `MirRelationExpr::Map` always appends all
126            //    newly-created columns to the end.
127            // 3) Nothing can be pushed through a leaf node. If `relation` is a leaf
128            //    node, `actual_projection` will always be `(0..relation.arity())`.
129            // Then, if `actual_projection` and `desired_projection` differ, we will
130            // add a project around `relation`.
131            let actual_projection = match relation {
132                MirRelationExpr::Constant { .. } => (0..relation.arity()).collect(),
133                MirRelationExpr::Get { id, .. } => {
134                    gets.entry(*id)
135                        .or_insert_with(BTreeSet::new)
136                        .extend(desired_projection.iter().cloned());
137                    (0..relation.arity()).collect()
138                }
139                MirRelationExpr::Let { id, value, body } => {
140                    // Let harvests any requirements of get from its body,
141                    // and pushes the sorted union of the requirements at its value.
142                    let id = Id::Local(*id);
143                    let prior = gets.insert(id, BTreeSet::new());
144                    self.action(body, desired_projection, gets)?;
145                    let desired_value_projection = gets.remove(&id).unwrap();
146                    if let Some(prior) = prior {
147                        gets.insert(id, prior);
148                    }
149                    let desired_value_projection =
150                        desired_value_projection.into_iter().collect::<Vec<_>>();
151                    self.action(value, &desired_value_projection, gets)?;
152                    let new_type = value.typ();
153                    self.update_projection_around_get(
154                        body,
155                        &BTreeMap::from_iter(std::iter::once((
156                            id,
157                            (desired_value_projection, new_type),
158                        ))),
159                    );
160                    desired_projection.clone()
161                }
162                MirRelationExpr::LetRec {
163                    ids,
164                    values,
165                    limits: _,
166                    body,
167                } => {
168                    // Determine the recursive IDs in this LetRec binding.
169                    let rec_ids = MirRelationExpr::recursive_ids(ids, values);
170
171                    // Seed the gets map with empty demand for each non-recursive ID.
172                    for id in ids.iter().filter(|id| !rec_ids.contains(id)) {
173                        let prior = gets.insert(Id::Local(*id), BTreeSet::new());
174                        assert_none!(prior);
175                    }
176
177                    // Descend into the body with the supplied desired_projection.
178                    self.action(body, desired_projection, gets)?;
179                    // Descend into the values in reverse order.
180                    for (id, value) in zip_eq(ids.iter().rev(), values.iter_mut().rev()) {
181                        let desired_projection = if rec_ids.contains(id) {
182                            // For recursive IDs: request all columns.
183                            let columns = 0..value.arity();
184                            columns.collect::<Vec<_>>()
185                        } else {
186                            // For non-recursive IDs: request the gets entry.
187                            let columns = gets.get(&Id::Local(*id)).unwrap();
188                            columns.iter().cloned().collect::<Vec<_>>()
189                        };
190                        self.action(value, &desired_projection, gets)?;
191                    }
192
193                    // Update projections around gets of non-recursive IDs.
194                    let mut updates = BTreeMap::new();
195                    for (id, value) in zip_eq(ids.iter(), values.iter_mut()) {
196                        // Update the current value.
197                        self.update_projection_around_get(value, &updates);
198                        // If this is a non-recursive ID, add an entry to the
199                        // updates map for subsequent values and the body.
200                        if !rec_ids.contains(id) {
201                            let new_type = value.typ();
202                            let new_proj = {
203                                let columns = gets.remove(&Id::Local(*id)).unwrap();
204                                columns.iter().cloned().collect::<Vec<_>>()
205                            };
206                            updates.insert(Id::Local(*id), (new_proj, new_type));
207                        }
208                    }
209                    // Update the body.
210                    self.update_projection_around_get(body, &updates);
211
212                    // Remove the entries for all ids (don't restrict only to
213                    // non-recursive IDs here for better hygene).
214                    for id in ids.iter() {
215                        gets.remove(&Id::Local(*id));
216                    }
217
218                    // Return the desired projection (leads to a no-op in the
219                    // projection handling logic after this match statement).
220                    desired_projection.clone()
221                }
222                MirRelationExpr::Join {
223                    inputs,
224                    equivalences,
225                    implementation,
226                } if self.include_joins => {
227                    assert!(
228                        matches!(implementation, JoinImplementation::Unimplemented),
229                        "ProjectionPushdown can't deal with filled in join implementations. Turn off `include_joins` if you'd like to run it after `JoinImplementation`."
230                    );
231
232                    let input_mapper = JoinInputMapper::new(inputs);
233
234                    let mut columns_to_pushdown =
235                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
236                    // Each equivalence class imposes internal demand for columns.
237                    for equivalence in equivalences.iter() {
238                        for expr in equivalence.iter() {
239                            expr.support_into(&mut columns_to_pushdown);
240                        }
241                    }
242
243                    // Populate child demands from external and internal demands.
244                    let new_columns =
245                        input_mapper.split_column_set_by_input(columns_to_pushdown.iter());
246
247                    // Recursively indicate the requirements.
248                    for (input, inp_columns) in inputs.iter_mut().zip(new_columns) {
249                        let inp_columns = inp_columns.into_iter().collect::<Vec<_>>();
250                        self.action(input, &inp_columns, gets)?;
251                    }
252
253                    reverse_permute(
254                        equivalences.iter_mut().flat_map(|e| e.iter_mut()),
255                        columns_to_pushdown.iter(),
256                    );
257
258                    columns_to_pushdown.into_iter().collect()
259                }
260                // Skip joins if `self.include_joins` is turned off.
261                MirRelationExpr::Join { inputs, equivalences: _, implementation: _ } => {
262                    let input_mapper = JoinInputMapper::new(inputs);
263
264                    // Include all columns.
265                    let columns_to_pushdown: Vec<_> = (0..input_mapper.total_columns()).collect();
266                    let child_columns =
267                        input_mapper.split_column_set_by_input(columns_to_pushdown.iter());
268
269                    // Recursively indicate the requirements.
270                    for (input, inp_columns) in inputs.iter_mut().zip(child_columns) {
271                        let inp_columns = inp_columns.into_iter().collect::<Vec<_>>();
272                        self.action(input, &inp_columns, gets)?;
273                    }
274
275                    columns_to_pushdown.into_iter().collect()
276                }
277                MirRelationExpr::FlatMap { input, func, exprs } => {
278                    let inner_arity = input.arity();
279                    // A FlatMap which returns zero rows acts like a filter
280                    // so we always need to execute it
281                    let mut columns_to_pushdown =
282                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
283                    for expr in exprs.iter() {
284                        expr.support_into(&mut columns_to_pushdown);
285                    }
286                    columns_to_pushdown.retain(|c| *c < inner_arity);
287
288                    reverse_permute(exprs.iter_mut(), columns_to_pushdown.iter());
289                    let columns_to_pushdown = columns_to_pushdown.into_iter().collect::<Vec<_>>();
290                    self.action(input, &columns_to_pushdown, gets)?;
291                    // The actual projection always has the newly-created columns at
292                    // the end.
293                    let mut actual_projection = columns_to_pushdown;
294                    for c in 0..func.output_type().arity() {
295                        actual_projection.push(inner_arity + c);
296                    }
297                    actual_projection
298                }
299                MirRelationExpr::Filter { input, predicates } => {
300                    let mut columns_to_pushdown =
301                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
302                    for predicate in predicates.iter() {
303                        predicate.support_into(&mut columns_to_pushdown);
304                    }
305                    reverse_permute(predicates.iter_mut(), columns_to_pushdown.iter());
306                    let columns_to_pushdown = columns_to_pushdown.into_iter().collect::<Vec<_>>();
307                    self.action(input, &columns_to_pushdown, gets)?;
308                    columns_to_pushdown
309                }
310                MirRelationExpr::Project { input, outputs } => {
311                    // Combine `outputs` with `desired_projection`.
312                    *outputs = desired_projection.iter().map(|c| outputs[*c]).collect();
313
314                    let unique_outputs = outputs.iter().map(|i| *i).collect::<BTreeSet<_>>();
315                    if outputs.len() == unique_outputs.len() {
316                        // Push down the project as is.
317                        self.action(input, outputs, gets)?;
318                        *relation = input.take_dangerous();
319                    } else {
320                        // Push down only the unique elems in `outputs`.
321                        let columns_to_pushdown = unique_outputs.into_iter().collect::<Vec<_>>();
322                        reverse_permute_columns(outputs.iter_mut(), columns_to_pushdown.iter());
323                        self.action(input, &columns_to_pushdown, gets)?;
324                    }
325
326                    desired_projection.clone()
327                }
328                MirRelationExpr::Map { input, scalars } => {
329                    let arity = input.arity();
330                    // contains columns whose supports have yet to be explored
331                    let mut actual_projection =
332                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
333                    for (i, scalar) in scalars.iter().enumerate().rev() {
334                        if actual_projection.contains(&(i + arity)) {
335                            scalar.support_into(&mut actual_projection);
336                        }
337                    }
338                    *scalars = (0..scalars.len())
339                        .filter_map(|i| {
340                            if actual_projection.contains(&(i + arity)) {
341                                Some(scalars[i].clone())
342                            } else {
343                                None
344                            }
345                        })
346                        .collect::<Vec<_>>();
347                    reverse_permute(scalars.iter_mut(), actual_projection.iter());
348                    self.action(
349                        input,
350                        &actual_projection
351                            .iter()
352                            .filter(|c| **c < arity)
353                            .map(|c| *c)
354                            .collect(),
355                        gets,
356                    )?;
357                    actual_projection.into_iter().collect()
358                }
359                MirRelationExpr::Reduce {
360                    input,
361                    group_key,
362                    aggregates,
363                    monotonic: _,
364                    expected_group_size: _,
365                } => {
366                    let mut columns_to_pushdown = BTreeSet::new();
367                    // Group keys determine aggregation granularity and are
368                    // each crucial in determining aggregates and even the
369                    // multiplicities of other keys.
370                    for k in group_key.iter() {
371                        k.support_into(&mut columns_to_pushdown)
372                    }
373
374                    for index in (0..aggregates.len()).rev() {
375                        if !desired_projection.contains(&(group_key.len() + index)) {
376                            aggregates.remove(index);
377                        } else {
378                            // No obvious requirements on aggregate columns.
379                            // A "non-empty" requirement, I guess?
380                            aggregates[index]
381                                .expr
382                                .support_into(&mut columns_to_pushdown)
383                        }
384                    }
385
386                    reverse_permute(
387                        itertools::chain!(
388                            group_key.iter_mut(),
389                            aggregates.iter_mut().map(|a| &mut a.expr)
390                        ),
391                        columns_to_pushdown.iter(),
392                    );
393
394                    self.action(
395                        input,
396                        &columns_to_pushdown.into_iter().collect::<Vec<_>>(),
397                        gets,
398                    )?;
399                    let mut actual_projection =
400                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
401                    actual_projection.extend(0..group_key.len());
402                    actual_projection.into_iter().collect()
403                }
404                MirRelationExpr::TopK {
405                    input,
406                    group_key,
407                    order_key,
408                    limit,
409                    ..
410                } => {
411                    // Group and order keys and limit support must be retained, as
412                    // they define which rows are retained.
413                    let mut columns_to_pushdown =
414                        desired_projection.iter().cloned().collect::<BTreeSet<_>>();
415                    columns_to_pushdown.extend(group_key.iter().cloned());
416                    columns_to_pushdown.extend(order_key.iter().map(|o| o.column));
417                    if let Some(limit) = limit.as_ref() {
418                        // Strictly speaking not needed because the
419                        // `limit` support should be a subset of the
420                        // `group_key` support, but we don't want to
421                        // take this for granted here.
422                        limit.support_into(&mut columns_to_pushdown);
423                    }
424                    // If the `TopK` does not have any new column demand, just push
425                    // down the desired projection. Otherwise, push down the sorted
426                    // column demand.
427                    let columns_to_pushdown =
428                        if columns_to_pushdown.len() == desired_projection.len() {
429                            desired_projection.clone()
430                        } else {
431                            columns_to_pushdown.into_iter().collect::<Vec<_>>()
432                        };
433                    reverse_permute_columns(
434                        itertools::chain!(
435                            group_key.iter_mut(),
436                            order_key.iter_mut().map(|o| &mut o.column),
437                        ),
438                        columns_to_pushdown.iter(),
439                    );
440                    reverse_permute(limit.iter_mut(), columns_to_pushdown.iter());
441                    self.action(input, &columns_to_pushdown, gets)?;
442                    columns_to_pushdown
443                }
444                MirRelationExpr::Negate { input } => {
445                    self.action(input, desired_projection, gets)?;
446                    desired_projection.clone()
447                }
448                MirRelationExpr::Union { base, inputs } => {
449                    self.action(base, desired_projection, gets)?;
450                    for input in inputs {
451                        self.action(input, desired_projection, gets)?;
452                    }
453                    desired_projection.clone()
454                }
455                MirRelationExpr::Threshold { input } => {
456                    // Threshold requires all columns, as collapsing any distinct values
457                    // has the potential to change how it thresholds counts. This could
458                    // be improved with reasoning about distinctness or non-negativity.
459                    let arity = input.arity();
460                    self.action(input, &(0..arity).collect(), gets)?;
461                    (0..arity).collect()
462                }
463                MirRelationExpr::ArrangeBy { input, keys: _ } => {
464                    // Do not push the project past the ArrangeBy.
465                    // TODO: how do we handle key sets containing column references
466                    // that are not demanded upstream?
467                    let arity = input.arity();
468                    self.action(input, &(0..arity).collect(), gets)?;
469                    (0..arity).collect()
470                }
471            };
472            let add_project = desired_projection != &actual_projection;
473            if add_project {
474                let mut projection_to_add = desired_projection.to_owned();
475                reverse_permute_columns(projection_to_add.iter_mut(), actual_projection.iter());
476                *relation = relation.take_dangerous().project(projection_to_add);
477            }
478            Ok(())
479        })
480    }
481
482    /// When we push the `desired_value_projection` at `value`,
483    /// the columns returned by `Get(get_id)` will change, so we need
484    /// to permute `Project`s around `Get(get_id)`.
485    pub fn update_projection_around_get(
486        &self,
487        relation: &mut MirRelationExpr,
488        applied_projections: &BTreeMap<Id, (Vec<usize>, mz_repr::RelationType)>,
489    ) {
490        relation.visit_pre_mut(|e| {
491            if let MirRelationExpr::Project { input, outputs } = e {
492                if let MirRelationExpr::Get {
493                    id: inner_id,
494                    typ,
495                    access_strategy: _,
496                } = &mut **input
497                {
498                    if let Some((new_projection, new_type)) = applied_projections.get(inner_id) {
499                        typ.clone_from(new_type);
500                        reverse_permute_columns(outputs.iter_mut(), new_projection.iter());
501                        if outputs.len() == new_projection.len()
502                            && outputs.iter().enumerate().all(|(i, o)| i == *o)
503                        {
504                            *e = input.take_dangerous();
505                        }
506                    }
507                }
508            }
509            // If there is no `Project` around a Get, all columns of
510            // `Get(get_id)` are required. Thus, the columns returned by
511            // `Get(get_id)` will not have changed, so no action
512            // is necessary.
513        });
514    }
515}
516
517/// Applies the reverse of [MirScalarExpr.permute] on each expression.
518///
519/// `permutation` can be thought of as a mapping of column references from
520/// `stateA` to `stateB`. [MirScalarExpr.permute] assumes that the column
521/// references of the expression are in `stateA` and need to be remapped to
522/// their `stateB` counterparts. This methods assumes that the column
523/// references are in `stateB` and need to be remapped to `stateA`.
524///
525/// The `outputs` field of [MirRelationExpr::Project] is a mapping from "after"
526/// to "before". Thus, when lifting projections, you would permute on `outputs`,
527/// but you need to reverse permute when pushing projections down.
528fn reverse_permute<'a, I, J>(exprs: I, permutation: J)
529where
530    I: Iterator<Item = &'a mut MirScalarExpr>,
531    J: Iterator<Item = &'a usize>,
532{
533    let reverse_col_map = permutation
534        .enumerate()
535        .map(|(idx, c)| (*c, idx))
536        .collect::<BTreeMap<_, _>>();
537    for expr in exprs {
538        expr.permute_map(&reverse_col_map);
539    }
540}
541
/// Column-number counterpart of [`reverse_permute`].
///
/// Every column number yielded by `columns` is overwritten with its position
/// in `permutation`: a column equal to `permutation[i]` becomes `i`.
///
/// # Panics
///
/// Panics if a column in `columns` does not occur in `permutation`.
fn reverse_permute_columns<'a, I, J>(columns: I, permutation: J)
where
    I: Iterator<Item = &'a mut usize>,
    J: Iterator<Item = &'a usize>,
{
    let position_of: BTreeMap<usize, usize> = permutation
        .enumerate()
        .map(|(position, &column)| (column, position))
        .collect();
    columns.for_each(|column| *column = position_of[&*column]);
}