Skip to main content

mz_transform/movement/
projection_lifting.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Hoist projections through operators.
11//!
12//! Projections can be re-introduced in the physical planning stage.
13
14use std::collections::BTreeMap;
15use std::mem;
16
17use itertools::zip_eq;
18use mz_expr::{AccessStrategy, Id, MirRelationExpr, RECURSION_LIMIT};
19use mz_ore::stack::{CheckedRecursion, RecursionGuard};
20use mz_repr::ReprRelationType;
21
22use crate::TransformCtx;
23
24/// Hoist projections through operators.
25#[derive(Debug)]
26pub struct ProjectionLifting {
27    recursion_guard: RecursionGuard,
28}
29
30impl Default for ProjectionLifting {
31    fn default() -> ProjectionLifting {
32        ProjectionLifting {
33            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
34        }
35    }
36}
37
38impl CheckedRecursion for ProjectionLifting {
39    fn recursion_guard(&self) -> &RecursionGuard {
40        &self.recursion_guard
41    }
42}
43
44impl crate::Transform for ProjectionLifting {
45    fn name(&self) -> &'static str {
46        "ProjectionLifting"
47    }
48
49    #[mz_ore::instrument(
50        target = "optimizer",
51        level = "debug",
52        fields(path.segment = "projection_lifting")
53    )]
54    fn actually_perform_transform(
55        &self,
56        relation: &mut MirRelationExpr,
57        _: &mut TransformCtx,
58    ) -> Result<(), crate::TransformError> {
59        let result = self.action(relation, &mut BTreeMap::new());
60        mz_repr::explain::trace_plan(&*relation);
61        result
62    }
63}
64
65impl ProjectionLifting {
66    /// Hoist projections through operators.
67    pub fn action(
68        &self,
69        relation: &mut MirRelationExpr,
70        // Map from names to new get type and projection required at use.
71        gets: &mut BTreeMap<Id, (ReprRelationType, Vec<usize>)>,
72    ) -> Result<(), crate::TransformError> {
73        self.checked_recur(|_| {
74            match relation {
75                MirRelationExpr::Constant { .. } => Ok(()),
76                MirRelationExpr::Get {
77                    id,
78                    typ: _,
79                    access_strategy: _,
80                } => {
81                    if let Some((typ, columns)) = gets.get(id) {
82                        *relation = MirRelationExpr::Get {
83                            id: *id,
84                            typ: typ.clone(),
85                            access_strategy: AccessStrategy::UnknownOrLocal, // (we are not copying it over)
86                        }
87                        .project(columns.clone());
88                    }
89                    Ok(())
90                }
91                MirRelationExpr::Let { id, value, body } => {
92                    self.action(value, gets)?;
93                    let id = Id::Local(*id);
94                    if let MirRelationExpr::Project { input, outputs } = &mut **value {
95                        let typ = input.typ();
96                        let prior = gets.insert(id, (typ, outputs.clone()));
97                        assert!(!prior.is_some());
98                        **value = input.take_dangerous();
99                    }
100
101                    self.action(body, gets)?;
102                    gets.remove(&id);
103                    Ok(())
104                }
105                MirRelationExpr::LetRec {
106                    ids,
107                    values,
108                    limits: _,
109                    body,
110                } => {
111                    let recursive_ids = MirRelationExpr::recursive_ids(ids, values);
112
113                    for (local_id, value) in zip_eq(ids.iter(), values.iter_mut()) {
114                        self.action(value, gets)?;
115                        if !recursive_ids.contains(local_id) {
116                            if let MirRelationExpr::Project { input, outputs } = value {
117                                let id = Id::Local(*local_id);
118                                let typ = input.typ();
119                                let prior = gets.insert(id, (typ, outputs.clone()));
120                                assert!(!prior.is_some());
121                                *value = input.take_dangerous();
122                            }
123                        }
124                    }
125
126                    self.action(body, gets)?;
127
128                    for local_id in ids.iter().filter(|id| !recursive_ids.contains(id)) {
129                        gets.remove(&Id::Local(*local_id));
130                    }
131
132                    Ok(())
133                }
134                MirRelationExpr::Project { input, outputs } => {
135                    self.action(input, gets)?;
136                    if let MirRelationExpr::Project {
137                        input: inner,
138                        outputs: inner_outputs,
139                    } = &mut **input
140                    {
141                        for output in outputs.iter_mut() {
142                            *output = inner_outputs[*output];
143                        }
144                        **input = inner.take_dangerous();
145                    }
146                    Ok(())
147                }
148                MirRelationExpr::Map { input, scalars } => {
149                    self.action(input, gets)?;
150                    if let MirRelationExpr::Project {
151                        input: inner,
152                        outputs,
153                    } = &mut **input
154                    {
155                        // Retain projected columns and scalar columns.
156                        let mut new_outputs = outputs.clone();
157                        let inner_arity = inner.arity();
158                        new_outputs.extend(inner_arity..(inner_arity + scalars.len()));
159
160                        // Rewrite scalar expressions using inner columns.
161                        for scalar in scalars.iter_mut() {
162                            scalar.permute(&new_outputs);
163                        }
164
165                        *relation = inner
166                            .take_dangerous()
167                            .map(scalars.clone())
168                            .project(new_outputs);
169                    }
170                    Ok(())
171                }
172                MirRelationExpr::FlatMap { input, func, exprs } => {
173                    self.action(input, gets)?;
174                    if let MirRelationExpr::Project {
175                        input: inner,
176                        outputs,
177                    } = &mut **input
178                    {
179                        // Retain projected columns and scalar columns.
180                        let mut new_outputs = outputs.clone();
181                        let inner_arity = inner.arity();
182                        new_outputs.extend(inner_arity..(inner_arity + func.output_arity()));
183
184                        // Rewrite scalar expression using inner columns.
185                        for expr in exprs.iter_mut() {
186                            expr.permute(&new_outputs);
187                        }
188
189                        *relation = inner
190                            .take_dangerous()
191                            .flat_map(func.clone(), exprs.clone())
192                            .project(new_outputs);
193                    }
194                    Ok(())
195                }
196                MirRelationExpr::Filter { input, predicates } => {
197                    self.action(input, gets)?;
198                    if let MirRelationExpr::Project {
199                        input: inner,
200                        outputs,
201                    } = &mut **input
202                    {
203                        // Rewrite scalar expressions using inner columns.
204                        for predicate in predicates.iter_mut() {
205                            predicate.permute(outputs);
206                        }
207                        *relation = inner
208                            .take_dangerous()
209                            .filter(predicates.clone())
210                            .project(outputs.clone());
211                    }
212                    Ok(())
213                }
214                MirRelationExpr::Join {
215                    inputs,
216                    equivalences,
217                    implementation,
218                } => {
219                    for input in inputs.iter_mut() {
220                        self.action(input, gets)?;
221                    }
222
223                    // Track the location of the projected columns in the un-projected join.
224                    let mut projection = Vec::new();
225                    let mut temp_arity = 0;
226
227                    for join_input in inputs.iter_mut() {
228                        if let MirRelationExpr::Project { input, outputs } = join_input {
229                            for output in outputs.iter() {
230                                projection.push(temp_arity + *output);
231                            }
232                            temp_arity += input.arity();
233                            *join_input = input.take_dangerous();
234                        } else {
235                            let arity = join_input.arity();
236                            projection.extend(temp_arity..(temp_arity + arity));
237                            temp_arity += arity;
238                        }
239                    }
240
241                    // Don't add the identity permutation as a projection.
242                    if projection.len() != temp_arity || (0..temp_arity).any(|i| projection[i] != i)
243                    {
244                        // Update equivalences and implementation.
245                        for equivalence in equivalences.iter_mut() {
246                            for expr in equivalence {
247                                expr.permute(&projection[..]);
248                            }
249                        }
250
251                        *implementation = mz_expr::JoinImplementation::Unimplemented;
252
253                        *relation = relation.take_dangerous().project(projection);
254                    }
255                    Ok(())
256                }
257                MirRelationExpr::Reduce {
258                    input,
259                    group_key,
260                    aggregates,
261                    monotonic: _,
262                    expected_group_size: _,
263                } => {
264                    // Reduce *absorbs* projections, which is amazing!
265                    self.action(input, gets)?;
266                    if let MirRelationExpr::Project {
267                        input: inner,
268                        outputs,
269                    } = &mut **input
270                    {
271                        for key in group_key.iter_mut() {
272                            key.permute(outputs);
273                        }
274                        for aggregate in aggregates.iter_mut() {
275                            aggregate.expr.permute(outputs);
276                        }
277                        **input = inner.take_dangerous();
278                    }
279                    Ok(())
280                }
281                MirRelationExpr::TopK {
282                    input,
283                    group_key,
284                    order_key,
285                    limit,
286                    offset,
287                    monotonic: _,
288                    expected_group_size,
289                } => {
290                    self.action(input, gets)?;
291                    if let MirRelationExpr::Project {
292                        input: inner,
293                        outputs,
294                    } = &mut **input
295                    {
296                        for key in group_key.iter_mut() {
297                            *key = outputs[*key];
298                        }
299                        for key in order_key.iter_mut() {
300                            key.column = outputs[key.column];
301                        }
302                        if let Some(limit) = limit.as_mut() {
303                            limit.permute(outputs);
304                        }
305                        *relation = inner
306                            .take_dangerous()
307                            .top_k(
308                                group_key.clone(),
309                                order_key.clone(),
310                                limit.clone(),
311                                offset.clone(),
312                                expected_group_size.clone(),
313                            )
314                            .project(outputs.clone());
315                    }
316                    Ok(())
317                }
318                MirRelationExpr::Negate { input } => {
319                    self.action(input, gets)?;
320                    if let MirRelationExpr::Project {
321                        input: inner,
322                        outputs,
323                    } = &mut **input
324                    {
325                        *relation = inner.take_dangerous().negate().project(outputs.clone());
326                    }
327                    Ok(())
328                }
329                MirRelationExpr::Threshold { input } => {
330                    // We cannot, in general, lift projections out of threshold.
331                    // If we could reason that the input cannot be negative, we
332                    // would be able to lift the projection, but otherwise our
333                    // action on weights need to accumulate the restricted rows.
334                    self.action(input, gets)
335                }
336                MirRelationExpr::Union { base, inputs } => {
337                    // We cannot, in general, lift projections out of unions.
338                    self.action(base, gets)?;
339                    for input in &mut *inputs {
340                        self.action(input, gets)?;
341                    }
342
343                    if let MirRelationExpr::Project {
344                        input: base_input,
345                        outputs: base_outputs,
346                    } = &mut **base
347                    {
348                        let base_typ = base_input.typ();
349
350                        let mut can_lift = true;
351                        for input in &mut *inputs {
352                            match input {
353                                MirRelationExpr::Project { input, outputs }
354                                    if input.typ() == base_typ && outputs == base_outputs => {}
355                                _ => {
356                                    can_lift = false;
357                                    break;
358                                }
359                            }
360                        }
361
362                        if can_lift {
363                            let base_outputs = mem::take(base_outputs);
364                            **base = base_input.take_dangerous();
365                            for inp in inputs {
366                                match inp {
367                                    MirRelationExpr::Project { input, .. } => {
368                                        *inp = input.take_dangerous();
369                                    }
370                                    _ => unreachable!(),
371                                }
372                            }
373                            *relation = relation.take_dangerous().project(base_outputs);
374                        }
375                    }
376                    Ok(())
377                }
378                MirRelationExpr::ArrangeBy { input, keys } => {
379                    self.action(input, gets)?;
380                    if let MirRelationExpr::Project {
381                        input: inner,
382                        outputs,
383                    } = &mut **input
384                    {
385                        for key_set in keys.iter_mut() {
386                            for key in key_set.iter_mut() {
387                                key.permute(outputs);
388                            }
389                        }
390                        *relation = inner
391                            .take_dangerous()
392                            .arrange_by(keys)
393                            .project(outputs.clone());
394                    }
395                    Ok(())
396                }
397            }
398        })
399    }
400}