mz_transform/canonicalization/
projection_extraction.rs1use mz_expr::visit::Visit;
13use mz_expr::{MirRelationExpr, MirScalarExpr};
14
15use crate::TransformCtx;
16
/// Extracts projections from `Map` and `Reduce` operators.
///
/// Bare column-reference scalars in a `Map`, and repeated group-key expressions or
/// aggregates in a `Reduce`, are removed from the operator and re-expressed as a
/// projection placed on top of it.
#[derive(Debug)]
pub struct ProjectionExtraction;
21
22impl crate::Transform for ProjectionExtraction {
23 fn name(&self) -> &'static str {
24 "ProjectionExtraction"
25 }
26
27 #[mz_ore::instrument(
28 target = "optimizer",
29 level = "debug",
30 fields(path.segment = "projection_extraction")
31 )]
32 fn actually_perform_transform(
33 &self,
34 relation: &mut MirRelationExpr,
35 _: &mut TransformCtx,
36 ) -> Result<(), crate::TransformError> {
37 relation.visit_mut_post(&mut Self::action)?;
38 mz_repr::explain::trace_plan(&*relation);
39 Ok(())
40 }
41}
42
43impl ProjectionExtraction {
44 pub fn action(relation: &mut MirRelationExpr) {
46 if let MirRelationExpr::Map { input, scalars } = relation {
47 if scalars
48 .iter()
49 .any(|s| matches!(s, MirScalarExpr::Column(_, _)))
50 {
51 let input_arity = input.arity();
52 let mut outputs: Vec<_> = (0..input_arity).collect();
53 let mut dropped = 0;
54 scalars.retain(|scalar| {
55 if let MirScalarExpr::Column(col, _) = scalar {
56 dropped += 1;
57 outputs.push(outputs[*col]);
60 false } else {
62 outputs.push(outputs.len() - dropped);
63 true }
65 });
66 if dropped > 0 {
67 for scalar in scalars {
68 scalar.permute(&outputs);
69 }
70 *relation = relation.take_dangerous().project(outputs);
71 }
72 }
73 } else if let MirRelationExpr::Reduce {
74 input: _,
75 group_key,
76 aggregates,
77 monotonic: _,
78 expected_group_size: _,
79 } = relation
80 {
81 let mut projection = Vec::new();
82
83 let mut finger = 0;
85 while finger < group_key.len() {
86 if let Some(position) = group_key[..finger]
87 .iter()
88 .position(|x| x == &group_key[finger])
89 {
90 projection.push(position);
91 group_key.remove(finger);
92 } else {
93 projection.push(finger);
94 finger += 1;
95 }
96 }
97
98 let mut finger = 0;
101 while finger < aggregates.len() {
102 if let Some(position) = aggregates[..finger]
103 .iter()
104 .position(|x| x == &aggregates[finger])
105 {
106 projection.push(group_key.len() + position);
107 aggregates.remove(finger);
108 } else {
109 projection.push(group_key.len() + finger);
110 finger += 1;
111 }
112 }
113 if projection.iter().enumerate().any(|(i, p)| i != *p) {
114 *relation = relation.take_dangerous().project(projection);
115 }
116 }
117 }
118}