mz_transform/movement/
projection_lifting.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Hoist projections through operators.
11//!
12//! Projections can be re-introduced in the physical planning stage.
13
14use std::collections::BTreeMap;
15use std::mem;
16
17use itertools::zip_eq;
18use mz_expr::{AccessStrategy, Id, MirRelationExpr, RECURSION_LIMIT};
19use mz_ore::stack::{CheckedRecursion, RecursionGuard};
20
21use crate::TransformCtx;
22
23/// Hoist projections through operators.
24#[derive(Debug)]
25pub struct ProjectionLifting {
26    recursion_guard: RecursionGuard,
27}
28
29impl Default for ProjectionLifting {
30    fn default() -> ProjectionLifting {
31        ProjectionLifting {
32            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
33        }
34    }
35}
36
37impl CheckedRecursion for ProjectionLifting {
38    fn recursion_guard(&self) -> &RecursionGuard {
39        &self.recursion_guard
40    }
41}
42
43impl crate::Transform for ProjectionLifting {
44    fn name(&self) -> &'static str {
45        "ProjectionLifting"
46    }
47
48    #[mz_ore::instrument(
49        target = "optimizer",
50        level = "debug",
51        fields(path.segment = "projection_lifting")
52    )]
53    fn actually_perform_transform(
54        &self,
55        relation: &mut MirRelationExpr,
56        _: &mut TransformCtx,
57    ) -> Result<(), crate::TransformError> {
58        let result = self.action(relation, &mut BTreeMap::new());
59        mz_repr::explain::trace_plan(&*relation);
60        result
61    }
62}
63
64impl ProjectionLifting {
65    /// Hoist projections through operators.
66    pub fn action(
67        &self,
68        relation: &mut MirRelationExpr,
69        // Map from names to new get type and projection required at use.
70        gets: &mut BTreeMap<Id, (mz_repr::RelationType, Vec<usize>)>,
71    ) -> Result<(), crate::TransformError> {
72        self.checked_recur(|_| {
73            match relation {
74                MirRelationExpr::Constant { .. } => Ok(()),
75                MirRelationExpr::Get {
76                    id,
77                    typ: _,
78                    access_strategy: _,
79                } => {
80                    if let Some((typ, columns)) = gets.get(id) {
81                        *relation = MirRelationExpr::Get {
82                            id: *id,
83                            typ: typ.clone(),
84                            access_strategy: AccessStrategy::UnknownOrLocal, // (we are not copying it over)
85                        }
86                        .project(columns.clone());
87                    }
88                    Ok(())
89                }
90                MirRelationExpr::Let { id, value, body } => {
91                    self.action(value, gets)?;
92                    let id = Id::Local(*id);
93                    if let MirRelationExpr::Project { input, outputs } = &mut **value {
94                        let typ = input.typ();
95                        let prior = gets.insert(id, (typ, outputs.clone()));
96                        assert!(!prior.is_some());
97                        **value = input.take_dangerous();
98                    }
99
100                    self.action(body, gets)?;
101                    gets.remove(&id);
102                    Ok(())
103                }
104                MirRelationExpr::LetRec {
105                    ids,
106                    values,
107                    limits: _,
108                    body,
109                } => {
110                    let recursive_ids = MirRelationExpr::recursive_ids(ids, values);
111
112                    for (local_id, value) in zip_eq(ids.iter(), values.iter_mut()) {
113                        self.action(value, gets)?;
114                        if !recursive_ids.contains(local_id) {
115                            if let MirRelationExpr::Project { input, outputs } = value {
116                                let id = Id::Local(*local_id);
117                                let typ = input.typ();
118                                let prior = gets.insert(id, (typ, outputs.clone()));
119                                assert!(!prior.is_some());
120                                *value = input.take_dangerous();
121                            }
122                        }
123                    }
124
125                    self.action(body, gets)?;
126
127                    for local_id in ids.iter().filter(|id| !recursive_ids.contains(id)) {
128                        gets.remove(&Id::Local(*local_id));
129                    }
130
131                    Ok(())
132                }
133                MirRelationExpr::Project { input, outputs } => {
134                    self.action(input, gets)?;
135                    if let MirRelationExpr::Project {
136                        input: inner,
137                        outputs: inner_outputs,
138                    } = &mut **input
139                    {
140                        for output in outputs.iter_mut() {
141                            *output = inner_outputs[*output];
142                        }
143                        **input = inner.take_dangerous();
144                    }
145                    Ok(())
146                }
147                MirRelationExpr::Map { input, scalars } => {
148                    self.action(input, gets)?;
149                    if let MirRelationExpr::Project {
150                        input: inner,
151                        outputs,
152                    } = &mut **input
153                    {
154                        // Retain projected columns and scalar columns.
155                        let mut new_outputs = outputs.clone();
156                        let inner_arity = inner.arity();
157                        new_outputs.extend(inner_arity..(inner_arity + scalars.len()));
158
159                        // Rewrite scalar expressions using inner columns.
160                        for scalar in scalars.iter_mut() {
161                            scalar.permute(&new_outputs);
162                        }
163
164                        *relation = inner
165                            .take_dangerous()
166                            .map(scalars.clone())
167                            .project(new_outputs);
168                    }
169                    Ok(())
170                }
171                MirRelationExpr::FlatMap { input, func, exprs } => {
172                    self.action(input, gets)?;
173                    if let MirRelationExpr::Project {
174                        input: inner,
175                        outputs,
176                    } = &mut **input
177                    {
178                        // Retain projected columns and scalar columns.
179                        let mut new_outputs = outputs.clone();
180                        let inner_arity = inner.arity();
181                        new_outputs.extend(inner_arity..(inner_arity + func.output_arity()));
182
183                        // Rewrite scalar expression using inner columns.
184                        for expr in exprs.iter_mut() {
185                            expr.permute(&new_outputs);
186                        }
187
188                        *relation = inner
189                            .take_dangerous()
190                            .flat_map(func.clone(), exprs.clone())
191                            .project(new_outputs);
192                    }
193                    Ok(())
194                }
195                MirRelationExpr::Filter { input, predicates } => {
196                    self.action(input, gets)?;
197                    if let MirRelationExpr::Project {
198                        input: inner,
199                        outputs,
200                    } = &mut **input
201                    {
202                        // Rewrite scalar expressions using inner columns.
203                        for predicate in predicates.iter_mut() {
204                            predicate.permute(outputs);
205                        }
206                        *relation = inner
207                            .take_dangerous()
208                            .filter(predicates.clone())
209                            .project(outputs.clone());
210                    }
211                    Ok(())
212                }
213                MirRelationExpr::Join {
214                    inputs,
215                    equivalences,
216                    implementation,
217                } => {
218                    for input in inputs.iter_mut() {
219                        self.action(input, gets)?;
220                    }
221
222                    // Track the location of the projected columns in the un-projected join.
223                    let mut projection = Vec::new();
224                    let mut temp_arity = 0;
225
226                    for join_input in inputs.iter_mut() {
227                        if let MirRelationExpr::Project { input, outputs } = join_input {
228                            for output in outputs.iter() {
229                                projection.push(temp_arity + *output);
230                            }
231                            temp_arity += input.arity();
232                            *join_input = input.take_dangerous();
233                        } else {
234                            let arity = join_input.arity();
235                            projection.extend(temp_arity..(temp_arity + arity));
236                            temp_arity += arity;
237                        }
238                    }
239
240                    // Don't add the identity permutation as a projection.
241                    if projection.len() != temp_arity || (0..temp_arity).any(|i| projection[i] != i)
242                    {
243                        // Update equivalences and implementation.
244                        for equivalence in equivalences.iter_mut() {
245                            for expr in equivalence {
246                                expr.permute(&projection[..]);
247                            }
248                        }
249
250                        *implementation = mz_expr::JoinImplementation::Unimplemented;
251
252                        *relation = relation.take_dangerous().project(projection);
253                    }
254                    Ok(())
255                }
256                MirRelationExpr::Reduce {
257                    input,
258                    group_key,
259                    aggregates,
260                    monotonic: _,
261                    expected_group_size: _,
262                } => {
263                    // Reduce *absorbs* projections, which is amazing!
264                    self.action(input, gets)?;
265                    if let MirRelationExpr::Project {
266                        input: inner,
267                        outputs,
268                    } = &mut **input
269                    {
270                        for key in group_key.iter_mut() {
271                            key.permute(outputs);
272                        }
273                        for aggregate in aggregates.iter_mut() {
274                            aggregate.expr.permute(outputs);
275                        }
276                        **input = inner.take_dangerous();
277                    }
278                    Ok(())
279                }
280                MirRelationExpr::TopK {
281                    input,
282                    group_key,
283                    order_key,
284                    limit,
285                    offset,
286                    monotonic: _,
287                    expected_group_size,
288                } => {
289                    self.action(input, gets)?;
290                    if let MirRelationExpr::Project {
291                        input: inner,
292                        outputs,
293                    } = &mut **input
294                    {
295                        for key in group_key.iter_mut() {
296                            *key = outputs[*key];
297                        }
298                        for key in order_key.iter_mut() {
299                            key.column = outputs[key.column];
300                        }
301                        if let Some(limit) = limit.as_mut() {
302                            limit.permute(outputs);
303                        }
304                        *relation = inner
305                            .take_dangerous()
306                            .top_k(
307                                group_key.clone(),
308                                order_key.clone(),
309                                limit.clone(),
310                                offset.clone(),
311                                expected_group_size.clone(),
312                            )
313                            .project(outputs.clone());
314                    }
315                    Ok(())
316                }
317                MirRelationExpr::Negate { input } => {
318                    self.action(input, gets)?;
319                    if let MirRelationExpr::Project {
320                        input: inner,
321                        outputs,
322                    } = &mut **input
323                    {
324                        *relation = inner.take_dangerous().negate().project(outputs.clone());
325                    }
326                    Ok(())
327                }
328                MirRelationExpr::Threshold { input } => {
329                    // We cannot, in general, lift projections out of threshold.
330                    // If we could reason that the input cannot be negative, we
331                    // would be able to lift the projection, but otherwise our
332                    // action on weights need to accumulate the restricted rows.
333                    self.action(input, gets)
334                }
335                MirRelationExpr::Union { base, inputs } => {
336                    // We cannot, in general, lift projections out of unions.
337                    self.action(base, gets)?;
338                    for input in &mut *inputs {
339                        self.action(input, gets)?;
340                    }
341
342                    if let MirRelationExpr::Project {
343                        input: base_input,
344                        outputs: base_outputs,
345                    } = &mut **base
346                    {
347                        let base_typ = base_input.typ();
348
349                        let mut can_lift = true;
350                        for input in &mut *inputs {
351                            match input {
352                                MirRelationExpr::Project { input, outputs }
353                                    if input.typ() == base_typ && outputs == base_outputs => {}
354                                _ => {
355                                    can_lift = false;
356                                    break;
357                                }
358                            }
359                        }
360
361                        if can_lift {
362                            let base_outputs = mem::take(base_outputs);
363                            **base = base_input.take_dangerous();
364                            for inp in inputs {
365                                match inp {
366                                    MirRelationExpr::Project { input, .. } => {
367                                        *inp = input.take_dangerous();
368                                    }
369                                    _ => unreachable!(),
370                                }
371                            }
372                            *relation = relation.take_dangerous().project(base_outputs);
373                        }
374                    }
375                    Ok(())
376                }
377                MirRelationExpr::ArrangeBy { input, keys } => {
378                    self.action(input, gets)?;
379                    if let MirRelationExpr::Project {
380                        input: inner,
381                        outputs,
382                    } = &mut **input
383                    {
384                        for key_set in keys.iter_mut() {
385                            for key in key_set.iter_mut() {
386                                key.permute(outputs);
387                            }
388                        }
389                        *relation = inner
390                            .take_dangerous()
391                            .arrange_by(keys)
392                            .project(outputs.clone());
393                    }
394                    Ok(())
395                }
396            }
397        })
398    }
399}