1use std::collections::BTreeMap;
15use std::mem;
16
17use itertools::zip_eq;
18use mz_expr::{AccessStrategy, Id, MirRelationExpr, RECURSION_LIMIT};
19use mz_ore::stack::{CheckedRecursion, RecursionGuard};
20use mz_repr::ReprRelationType;
21
22use crate::TransformCtx;
23
24#[derive(Debug)]
26pub struct ProjectionLifting {
27 recursion_guard: RecursionGuard,
28}
29
30impl Default for ProjectionLifting {
31 fn default() -> ProjectionLifting {
32 ProjectionLifting {
33 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
34 }
35 }
36}
37
38impl CheckedRecursion for ProjectionLifting {
39 fn recursion_guard(&self) -> &RecursionGuard {
40 &self.recursion_guard
41 }
42}
43
44impl crate::Transform for ProjectionLifting {
45 fn name(&self) -> &'static str {
46 "ProjectionLifting"
47 }
48
49 #[mz_ore::instrument(
50 target = "optimizer",
51 level = "debug",
52 fields(path.segment = "projection_lifting")
53 )]
54 fn actually_perform_transform(
55 &self,
56 relation: &mut MirRelationExpr,
57 _: &mut TransformCtx,
58 ) -> Result<(), crate::TransformError> {
59 let result = self.action(relation, &mut BTreeMap::new());
60 mz_repr::explain::trace_plan(&*relation);
61 result
62 }
63}
64
65impl ProjectionLifting {
66 pub fn action(
68 &self,
69 relation: &mut MirRelationExpr,
70 gets: &mut BTreeMap<Id, (ReprRelationType, Vec<usize>)>,
72 ) -> Result<(), crate::TransformError> {
73 self.checked_recur(|_| {
74 match relation {
75 MirRelationExpr::Constant { .. } => Ok(()),
76 MirRelationExpr::Get {
77 id,
78 typ: _,
79 access_strategy: _,
80 } => {
81 if let Some((typ, columns)) = gets.get(id) {
82 *relation = MirRelationExpr::Get {
83 id: *id,
84 typ: typ.clone(),
85 access_strategy: AccessStrategy::UnknownOrLocal, }
87 .project(columns.clone());
88 }
89 Ok(())
90 }
91 MirRelationExpr::Let { id, value, body } => {
92 self.action(value, gets)?;
93 let id = Id::Local(*id);
94 if let MirRelationExpr::Project { input, outputs } = &mut **value {
95 let typ = input.typ();
96 let prior = gets.insert(id, (typ, outputs.clone()));
97 assert!(!prior.is_some());
98 **value = input.take_dangerous();
99 }
100
101 self.action(body, gets)?;
102 gets.remove(&id);
103 Ok(())
104 }
105 MirRelationExpr::LetRec {
106 ids,
107 values,
108 limits: _,
109 body,
110 } => {
111 let recursive_ids = MirRelationExpr::recursive_ids(ids, values);
112
113 for (local_id, value) in zip_eq(ids.iter(), values.iter_mut()) {
114 self.action(value, gets)?;
115 if !recursive_ids.contains(local_id) {
116 if let MirRelationExpr::Project { input, outputs } = value {
117 let id = Id::Local(*local_id);
118 let typ = input.typ();
119 let prior = gets.insert(id, (typ, outputs.clone()));
120 assert!(!prior.is_some());
121 *value = input.take_dangerous();
122 }
123 }
124 }
125
126 self.action(body, gets)?;
127
128 for local_id in ids.iter().filter(|id| !recursive_ids.contains(id)) {
129 gets.remove(&Id::Local(*local_id));
130 }
131
132 Ok(())
133 }
134 MirRelationExpr::Project { input, outputs } => {
135 self.action(input, gets)?;
136 if let MirRelationExpr::Project {
137 input: inner,
138 outputs: inner_outputs,
139 } = &mut **input
140 {
141 for output in outputs.iter_mut() {
142 *output = inner_outputs[*output];
143 }
144 **input = inner.take_dangerous();
145 }
146 Ok(())
147 }
148 MirRelationExpr::Map { input, scalars } => {
149 self.action(input, gets)?;
150 if let MirRelationExpr::Project {
151 input: inner,
152 outputs,
153 } = &mut **input
154 {
155 let mut new_outputs = outputs.clone();
157 let inner_arity = inner.arity();
158 new_outputs.extend(inner_arity..(inner_arity + scalars.len()));
159
160 for scalar in scalars.iter_mut() {
162 scalar.permute(&new_outputs);
163 }
164
165 *relation = inner
166 .take_dangerous()
167 .map(scalars.clone())
168 .project(new_outputs);
169 }
170 Ok(())
171 }
172 MirRelationExpr::FlatMap { input, func, exprs } => {
173 self.action(input, gets)?;
174 if let MirRelationExpr::Project {
175 input: inner,
176 outputs,
177 } = &mut **input
178 {
179 let mut new_outputs = outputs.clone();
181 let inner_arity = inner.arity();
182 new_outputs.extend(inner_arity..(inner_arity + func.output_arity()));
183
184 for expr in exprs.iter_mut() {
186 expr.permute(&new_outputs);
187 }
188
189 *relation = inner
190 .take_dangerous()
191 .flat_map(func.clone(), exprs.clone())
192 .project(new_outputs);
193 }
194 Ok(())
195 }
196 MirRelationExpr::Filter { input, predicates } => {
197 self.action(input, gets)?;
198 if let MirRelationExpr::Project {
199 input: inner,
200 outputs,
201 } = &mut **input
202 {
203 for predicate in predicates.iter_mut() {
205 predicate.permute(outputs);
206 }
207 *relation = inner
208 .take_dangerous()
209 .filter(predicates.clone())
210 .project(outputs.clone());
211 }
212 Ok(())
213 }
214 MirRelationExpr::Join {
215 inputs,
216 equivalences,
217 implementation,
218 } => {
219 for input in inputs.iter_mut() {
220 self.action(input, gets)?;
221 }
222
223 let mut projection = Vec::new();
225 let mut temp_arity = 0;
226
227 for join_input in inputs.iter_mut() {
228 if let MirRelationExpr::Project { input, outputs } = join_input {
229 for output in outputs.iter() {
230 projection.push(temp_arity + *output);
231 }
232 temp_arity += input.arity();
233 *join_input = input.take_dangerous();
234 } else {
235 let arity = join_input.arity();
236 projection.extend(temp_arity..(temp_arity + arity));
237 temp_arity += arity;
238 }
239 }
240
241 if projection.len() != temp_arity || (0..temp_arity).any(|i| projection[i] != i)
243 {
244 for equivalence in equivalences.iter_mut() {
246 for expr in equivalence {
247 expr.permute(&projection[..]);
248 }
249 }
250
251 *implementation = mz_expr::JoinImplementation::Unimplemented;
252
253 *relation = relation.take_dangerous().project(projection);
254 }
255 Ok(())
256 }
257 MirRelationExpr::Reduce {
258 input,
259 group_key,
260 aggregates,
261 monotonic: _,
262 expected_group_size: _,
263 } => {
264 self.action(input, gets)?;
266 if let MirRelationExpr::Project {
267 input: inner,
268 outputs,
269 } = &mut **input
270 {
271 for key in group_key.iter_mut() {
272 key.permute(outputs);
273 }
274 for aggregate in aggregates.iter_mut() {
275 aggregate.expr.permute(outputs);
276 }
277 **input = inner.take_dangerous();
278 }
279 Ok(())
280 }
281 MirRelationExpr::TopK {
282 input,
283 group_key,
284 order_key,
285 limit,
286 offset,
287 monotonic: _,
288 expected_group_size,
289 } => {
290 self.action(input, gets)?;
291 if let MirRelationExpr::Project {
292 input: inner,
293 outputs,
294 } = &mut **input
295 {
296 for key in group_key.iter_mut() {
297 *key = outputs[*key];
298 }
299 for key in order_key.iter_mut() {
300 key.column = outputs[key.column];
301 }
302 if let Some(limit) = limit.as_mut() {
303 limit.permute(outputs);
304 }
305 *relation = inner
306 .take_dangerous()
307 .top_k(
308 group_key.clone(),
309 order_key.clone(),
310 limit.clone(),
311 offset.clone(),
312 expected_group_size.clone(),
313 )
314 .project(outputs.clone());
315 }
316 Ok(())
317 }
318 MirRelationExpr::Negate { input } => {
319 self.action(input, gets)?;
320 if let MirRelationExpr::Project {
321 input: inner,
322 outputs,
323 } = &mut **input
324 {
325 *relation = inner.take_dangerous().negate().project(outputs.clone());
326 }
327 Ok(())
328 }
329 MirRelationExpr::Threshold { input } => {
330 self.action(input, gets)
335 }
336 MirRelationExpr::Union { base, inputs } => {
337 self.action(base, gets)?;
339 for input in &mut *inputs {
340 self.action(input, gets)?;
341 }
342
343 if let MirRelationExpr::Project {
344 input: base_input,
345 outputs: base_outputs,
346 } = &mut **base
347 {
348 let base_typ = base_input.typ();
349
350 let mut can_lift = true;
351 for input in &mut *inputs {
352 match input {
353 MirRelationExpr::Project { input, outputs }
354 if input.typ() == base_typ && outputs == base_outputs => {}
355 _ => {
356 can_lift = false;
357 break;
358 }
359 }
360 }
361
362 if can_lift {
363 let base_outputs = mem::take(base_outputs);
364 **base = base_input.take_dangerous();
365 for inp in inputs {
366 match inp {
367 MirRelationExpr::Project { input, .. } => {
368 *inp = input.take_dangerous();
369 }
370 _ => unreachable!(),
371 }
372 }
373 *relation = relation.take_dangerous().project(base_outputs);
374 }
375 }
376 Ok(())
377 }
378 MirRelationExpr::ArrangeBy { input, keys } => {
379 self.action(input, gets)?;
380 if let MirRelationExpr::Project {
381 input: inner,
382 outputs,
383 } = &mut **input
384 {
385 for key_set in keys.iter_mut() {
386 for key in key_set.iter_mut() {
387 key.permute(outputs);
388 }
389 }
390 *relation = inner
391 .take_dangerous()
392 .arrange_by(keys)
393 .project(outputs.clone());
394 }
395 Ok(())
396 }
397 }
398 })
399 }
400}