mz_transform/canonicalization/
projection_extraction.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Transform column references in a `Map`, and repeated group keys or
//! aggregates in a `Reduce`, into a `Project`.

12use mz_expr::visit::Visit;
13use mz_expr::{MirRelationExpr, MirScalarExpr};
14
15use crate::TransformCtx;
16
17/// Transform column references in a `Map` into a `Project`, or repeated
18/// aggregations in a `Reduce` into a `Project`.
19#[derive(Debug)]
20pub struct ProjectionExtraction;
21
22impl crate::Transform for ProjectionExtraction {
23    fn name(&self) -> &'static str {
24        "ProjectionExtraction"
25    }
26
27    #[mz_ore::instrument(
28        target = "optimizer",
29        level = "debug",
30        fields(path.segment = "projection_extraction")
31    )]
32    fn actually_perform_transform(
33        &self,
34        relation: &mut MirRelationExpr,
35        _: &mut TransformCtx,
36    ) -> Result<(), crate::TransformError> {
37        relation.visit_mut_post(&mut Self::action)?;
38        mz_repr::explain::trace_plan(&*relation);
39        Ok(())
40    }
41}
42
43impl ProjectionExtraction {
44    /// Transform column references in a `Map` into a `Project`.
45    pub fn action(relation: &mut MirRelationExpr) {
46        if let MirRelationExpr::Map { input, scalars } = relation {
47            if scalars
48                .iter()
49                .any(|s| matches!(s, MirScalarExpr::Column(_, _)))
50            {
51                let input_arity = input.arity();
52                let mut outputs: Vec<_> = (0..input_arity).collect();
53                let mut dropped = 0;
54                scalars.retain(|scalar| {
55                    if let MirScalarExpr::Column(col, _) = scalar {
56                        dropped += 1;
57                        // We may need to chase down a few levels of indirection;
58                        // find the original input column in `outputs[*col]`.
59                        outputs.push(outputs[*col]);
60                        false // don't retain
61                    } else {
62                        outputs.push(outputs.len() - dropped);
63                        true // retain
64                    }
65                });
66                if dropped > 0 {
67                    for scalar in scalars {
68                        scalar.permute(&outputs);
69                    }
70                    *relation = relation.take_dangerous().project(outputs);
71                }
72            }
73        } else if let MirRelationExpr::Reduce {
74            input: _,
75            group_key,
76            aggregates,
77            monotonic: _,
78            expected_group_size: _,
79        } = relation
80        {
81            let mut projection = Vec::new();
82
83            // If any key is an exact duplicate, we can remove it and use a projection.
84            let mut finger = 0;
85            while finger < group_key.len() {
86                if let Some(position) = group_key[..finger]
87                    .iter()
88                    .position(|x| x == &group_key[finger])
89                {
90                    projection.push(position);
91                    group_key.remove(finger);
92                } else {
93                    projection.push(finger);
94                    finger += 1;
95                }
96            }
97
98            // If any entry of aggregates exists earlier in aggregates, we can remove it
99            // and replace it with a projection that points to the first instance of it.
100            let mut finger = 0;
101            while finger < aggregates.len() {
102                if let Some(position) = aggregates[..finger]
103                    .iter()
104                    .position(|x| x == &aggregates[finger])
105                {
106                    projection.push(group_key.len() + position);
107                    aggregates.remove(finger);
108                } else {
109                    projection.push(group_key.len() + finger);
110                    finger += 1;
111                }
112            }
113            if projection.iter().enumerate().any(|(i, p)| i != *p) {
114                *relation = relation.take_dangerous().project(projection);
115            }
116        }
117    }
118}