1use std::collections::BTreeMap;
15use std::mem;
16
17use itertools::zip_eq;
18use mz_expr::{AccessStrategy, Id, MirRelationExpr, RECURSION_LIMIT};
19use mz_ore::stack::{CheckedRecursion, RecursionGuard};
20
21use crate::TransformCtx;
22
/// Optimizer transform that lifts `Project` operators toward the root of a
/// `MirRelationExpr`, so that operators below a projection can be rewritten
/// without the projection standing in the way.
#[derive(Debug)]
pub struct ProjectionLifting {
    /// Guard consulted via `CheckedRecursion::checked_recur` to bound the
    /// recursion depth of `ProjectionLifting::action`.
    recursion_guard: RecursionGuard,
}
28
29impl Default for ProjectionLifting {
30 fn default() -> ProjectionLifting {
31 ProjectionLifting {
32 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
33 }
34 }
35}
36
impl CheckedRecursion for ProjectionLifting {
    /// Returns the guard that `checked_recur` (used by `action`) consults to
    /// limit how deeply the transform may recurse.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
42
43impl crate::Transform for ProjectionLifting {
44 fn name(&self) -> &'static str {
45 "ProjectionLifting"
46 }
47
48 #[mz_ore::instrument(
49 target = "optimizer",
50 level = "debug",
51 fields(path.segment = "projection_lifting")
52 )]
53 fn actually_perform_transform(
54 &self,
55 relation: &mut MirRelationExpr,
56 _: &mut TransformCtx,
57 ) -> Result<(), crate::TransformError> {
58 let result = self.action(relation, &mut BTreeMap::new());
59 mz_repr::explain::trace_plan(&*relation);
60 result
61 }
62}
63
64impl ProjectionLifting {
65 pub fn action(
67 &self,
68 relation: &mut MirRelationExpr,
69 gets: &mut BTreeMap<Id, (mz_repr::RelationType, Vec<usize>)>,
71 ) -> Result<(), crate::TransformError> {
72 self.checked_recur(|_| {
73 match relation {
74 MirRelationExpr::Constant { .. } => Ok(()),
75 MirRelationExpr::Get {
76 id,
77 typ: _,
78 access_strategy: _,
79 } => {
80 if let Some((typ, columns)) = gets.get(id) {
81 *relation = MirRelationExpr::Get {
82 id: *id,
83 typ: typ.clone(),
84 access_strategy: AccessStrategy::UnknownOrLocal, }
86 .project(columns.clone());
87 }
88 Ok(())
89 }
90 MirRelationExpr::Let { id, value, body } => {
91 self.action(value, gets)?;
92 let id = Id::Local(*id);
93 if let MirRelationExpr::Project { input, outputs } = &mut **value {
94 let typ = input.typ();
95 let prior = gets.insert(id, (typ, outputs.clone()));
96 assert!(!prior.is_some());
97 **value = input.take_dangerous();
98 }
99
100 self.action(body, gets)?;
101 gets.remove(&id);
102 Ok(())
103 }
104 MirRelationExpr::LetRec {
105 ids,
106 values,
107 limits: _,
108 body,
109 } => {
110 let recursive_ids = MirRelationExpr::recursive_ids(ids, values);
111
112 for (local_id, value) in zip_eq(ids.iter(), values.iter_mut()) {
113 self.action(value, gets)?;
114 if !recursive_ids.contains(local_id) {
115 if let MirRelationExpr::Project { input, outputs } = value {
116 let id = Id::Local(*local_id);
117 let typ = input.typ();
118 let prior = gets.insert(id, (typ, outputs.clone()));
119 assert!(!prior.is_some());
120 *value = input.take_dangerous();
121 }
122 }
123 }
124
125 self.action(body, gets)?;
126
127 for local_id in ids.iter().filter(|id| !recursive_ids.contains(id)) {
128 gets.remove(&Id::Local(*local_id));
129 }
130
131 Ok(())
132 }
133 MirRelationExpr::Project { input, outputs } => {
134 self.action(input, gets)?;
135 if let MirRelationExpr::Project {
136 input: inner,
137 outputs: inner_outputs,
138 } = &mut **input
139 {
140 for output in outputs.iter_mut() {
141 *output = inner_outputs[*output];
142 }
143 **input = inner.take_dangerous();
144 }
145 Ok(())
146 }
147 MirRelationExpr::Map { input, scalars } => {
148 self.action(input, gets)?;
149 if let MirRelationExpr::Project {
150 input: inner,
151 outputs,
152 } = &mut **input
153 {
154 let mut new_outputs = outputs.clone();
156 let inner_arity = inner.arity();
157 new_outputs.extend(inner_arity..(inner_arity + scalars.len()));
158
159 for scalar in scalars.iter_mut() {
161 scalar.permute(&new_outputs);
162 }
163
164 *relation = inner
165 .take_dangerous()
166 .map(scalars.clone())
167 .project(new_outputs);
168 }
169 Ok(())
170 }
171 MirRelationExpr::FlatMap { input, func, exprs } => {
172 self.action(input, gets)?;
173 if let MirRelationExpr::Project {
174 input: inner,
175 outputs,
176 } = &mut **input
177 {
178 let mut new_outputs = outputs.clone();
180 let inner_arity = inner.arity();
181 new_outputs.extend(inner_arity..(inner_arity + func.output_arity()));
182
183 for expr in exprs.iter_mut() {
185 expr.permute(&new_outputs);
186 }
187
188 *relation = inner
189 .take_dangerous()
190 .flat_map(func.clone(), exprs.clone())
191 .project(new_outputs);
192 }
193 Ok(())
194 }
195 MirRelationExpr::Filter { input, predicates } => {
196 self.action(input, gets)?;
197 if let MirRelationExpr::Project {
198 input: inner,
199 outputs,
200 } = &mut **input
201 {
202 for predicate in predicates.iter_mut() {
204 predicate.permute(outputs);
205 }
206 *relation = inner
207 .take_dangerous()
208 .filter(predicates.clone())
209 .project(outputs.clone());
210 }
211 Ok(())
212 }
213 MirRelationExpr::Join {
214 inputs,
215 equivalences,
216 implementation,
217 } => {
218 for input in inputs.iter_mut() {
219 self.action(input, gets)?;
220 }
221
222 let mut projection = Vec::new();
224 let mut temp_arity = 0;
225
226 for join_input in inputs.iter_mut() {
227 if let MirRelationExpr::Project { input, outputs } = join_input {
228 for output in outputs.iter() {
229 projection.push(temp_arity + *output);
230 }
231 temp_arity += input.arity();
232 *join_input = input.take_dangerous();
233 } else {
234 let arity = join_input.arity();
235 projection.extend(temp_arity..(temp_arity + arity));
236 temp_arity += arity;
237 }
238 }
239
240 if projection.len() != temp_arity || (0..temp_arity).any(|i| projection[i] != i)
242 {
243 for equivalence in equivalences.iter_mut() {
245 for expr in equivalence {
246 expr.permute(&projection[..]);
247 }
248 }
249
250 *implementation = mz_expr::JoinImplementation::Unimplemented;
251
252 *relation = relation.take_dangerous().project(projection);
253 }
254 Ok(())
255 }
256 MirRelationExpr::Reduce {
257 input,
258 group_key,
259 aggregates,
260 monotonic: _,
261 expected_group_size: _,
262 } => {
263 self.action(input, gets)?;
265 if let MirRelationExpr::Project {
266 input: inner,
267 outputs,
268 } = &mut **input
269 {
270 for key in group_key.iter_mut() {
271 key.permute(outputs);
272 }
273 for aggregate in aggregates.iter_mut() {
274 aggregate.expr.permute(outputs);
275 }
276 **input = inner.take_dangerous();
277 }
278 Ok(())
279 }
280 MirRelationExpr::TopK {
281 input,
282 group_key,
283 order_key,
284 limit,
285 offset,
286 monotonic: _,
287 expected_group_size,
288 } => {
289 self.action(input, gets)?;
290 if let MirRelationExpr::Project {
291 input: inner,
292 outputs,
293 } = &mut **input
294 {
295 for key in group_key.iter_mut() {
296 *key = outputs[*key];
297 }
298 for key in order_key.iter_mut() {
299 key.column = outputs[key.column];
300 }
301 if let Some(limit) = limit.as_mut() {
302 limit.permute(outputs);
303 }
304 *relation = inner
305 .take_dangerous()
306 .top_k(
307 group_key.clone(),
308 order_key.clone(),
309 limit.clone(),
310 offset.clone(),
311 expected_group_size.clone(),
312 )
313 .project(outputs.clone());
314 }
315 Ok(())
316 }
317 MirRelationExpr::Negate { input } => {
318 self.action(input, gets)?;
319 if let MirRelationExpr::Project {
320 input: inner,
321 outputs,
322 } = &mut **input
323 {
324 *relation = inner.take_dangerous().negate().project(outputs.clone());
325 }
326 Ok(())
327 }
328 MirRelationExpr::Threshold { input } => {
329 self.action(input, gets)
334 }
335 MirRelationExpr::Union { base, inputs } => {
336 self.action(base, gets)?;
338 for input in &mut *inputs {
339 self.action(input, gets)?;
340 }
341
342 if let MirRelationExpr::Project {
343 input: base_input,
344 outputs: base_outputs,
345 } = &mut **base
346 {
347 let base_typ = base_input.typ();
348
349 let mut can_lift = true;
350 for input in &mut *inputs {
351 match input {
352 MirRelationExpr::Project { input, outputs }
353 if input.typ() == base_typ && outputs == base_outputs => {}
354 _ => {
355 can_lift = false;
356 break;
357 }
358 }
359 }
360
361 if can_lift {
362 let base_outputs = mem::take(base_outputs);
363 **base = base_input.take_dangerous();
364 for inp in inputs {
365 match inp {
366 MirRelationExpr::Project { input, .. } => {
367 *inp = input.take_dangerous();
368 }
369 _ => unreachable!(),
370 }
371 }
372 *relation = relation.take_dangerous().project(base_outputs);
373 }
374 }
375 Ok(())
376 }
377 MirRelationExpr::ArrangeBy { input, keys } => {
378 self.action(input, gets)?;
379 if let MirRelationExpr::Project {
380 input: inner,
381 outputs,
382 } = &mut **input
383 {
384 for key_set in keys.iter_mut() {
385 for key in key_set.iter_mut() {
386 key.permute(outputs);
387 }
388 }
389 *relation = inner
390 .take_dangerous()
391 .arrange_by(keys)
392 .project(outputs.clone());
393 }
394 Ok(())
395 }
396 }
397 })
398 }
399}