Skip to main content

mz_sql/plan/
transform_hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::func::variadic::RecordCreate;
19use mz_expr::visit::Visit;
20use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
21use mz_ore::stack::RecursionLimitError;
22use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
23
24use crate::plan::hir::{
25    AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
26    ValueWindowExpr, ValueWindowFunc, WindowExpr,
27};
28use crate::plan::{AggregateExpr, WindowExprType};
29
30/// Rewrites predicates that contain subqueries so that the subqueries
31/// appear in their own later predicate when possible.
32///
33/// For example, this function rewrites this expression
34///
35/// ```text
36/// Filter {
37///     predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
38/// }
39/// ```
40///
41/// like so:
42///
43/// ```text
44/// Filter {
45///     predicates: [
46///         a = b AND c = d,
47///         EXISTS (<subquery>),
48///         (<subquery 2>) = e,
49///     ]
50/// }
51/// ```
52///
53/// The rewrite causes decorrelation to incorporate prior predicates into
54/// the outer relation upon which the subquery is evaluated. In the above
55/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
56/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
57/// = e`, will be further restricted to outer rows that match `A = b AND c =
58/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
59/// subquery, especially when the original conjunction contains join keys.
60pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61    fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
62        #[allow(deprecated)]
63        expr.visit_mut_fallible(0, &mut |expr, _| {
64            match expr {
65                HirRelationExpr::Map { scalars, .. } => {
66                    for scalar in scalars {
67                        walk_scalar(scalar)?;
68                    }
69                }
70                HirRelationExpr::CallTable { exprs, .. } => {
71                    for expr in exprs {
72                        walk_scalar(expr)?;
73                    }
74                }
75                HirRelationExpr::Filter { predicates, .. } => {
76                    let mut subqueries = vec![];
77                    for predicate in &mut *predicates {
78                        walk_scalar(predicate)?;
79                        extract_conjuncted_subqueries(predicate, &mut subqueries)?;
80                    }
81                    // TODO(benesch): we could be smarter about the order in which
82                    // we emit subqueries. At the moment we just emit in the order
83                    // we discovered them, but ideally we'd emit them in an order
84                    // that accounted for their cost/selectivity. E.g., low-cost,
85                    // high-selectivity subqueries should go first.
86                    for subquery in subqueries {
87                        predicates.push(subquery);
88                    }
89                }
90                _ => (),
91            }
92            Ok(())
93        })
94    }
95
96    fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
97        expr.try_visit_mut_post(&mut |expr| {
98            match expr {
99                HirScalarExpr::Exists(input, _name) | HirScalarExpr::Select(input, _name) => {
100                    walk_relation(input)?
101                }
102                _ => (),
103            }
104            Ok(())
105        })
106    }
107
108    fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
109        let mut found = false;
110        expr.visit_pre(&mut |expr| match expr {
111            HirScalarExpr::Exists(..) | HirScalarExpr::Select(..) => found = true,
112            _ => (),
113        })?;
114        Ok(found)
115    }
116
117    /// Extracts subqueries from a conjunction into `out`.
118    ///
119    /// For example, given an expression like
120    ///
121    /// ```text
122    /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
123    /// ```
124    ///
125    /// this function rewrites the expression to
126    ///
127    /// ```text
128    /// a = b AND true AND c = d AND true
129    /// ```
130    ///
131    /// and returns the expression fragments `EXISTS (<subquery 1>)` and
132    /// `(<subquery 2>) = e` in the `out` vector.
133    fn extract_conjuncted_subqueries(
134        expr: &mut HirScalarExpr,
135        out: &mut Vec<HirScalarExpr>,
136    ) -> Result<(), RecursionLimitError> {
137        match expr {
138            HirScalarExpr::CallVariadic {
139                func: VariadicFunc::And(_),
140                exprs,
141                name: _,
142            } => {
143                exprs
144                    .into_iter()
145                    .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
146            }
147            expr if contains_subquery(expr)? => {
148                out.push(mem::replace(expr, HirScalarExpr::literal_true()))
149            }
150            _ => (),
151        }
152        Ok(())
153    }
154
155    walk_relation(expr)
156}
157
158/// Rewrites quantified comparisons into simpler EXISTS operators.
159///
160/// Note that this transformation is only valid when the expression is
161/// used in a context where the distinction between `FALSE` and `NULL`
162/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
163/// when the inputs to the comparison are non-nullable. This function is careful
164/// to only apply the transformation when it is valid to do so.
165///
166/// ```ignore
167/// WHERE (SELECT any(<pred>) FROM <rel>)
168/// =>
169/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
170///
171/// WHERE (SELECT all(<pred>) FROM <rel>)
172/// =>
173/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
174/// ```
175///
176/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
177/// M. Elhemali, et al.
178pub fn try_simplify_quantified_comparisons(
179    expr: &mut HirRelationExpr,
180    simplify_join_on: bool,
181) -> Result<(), RecursionLimitError> {
182    fn walk_relation(
183        expr: &mut HirRelationExpr,
184        outers: &[SqlRelationType],
185        simplify_join_on: bool,
186    ) -> Result<(), RecursionLimitError> {
187        match expr {
188            HirRelationExpr::Map { scalars, input } => {
189                walk_relation(input, outers, simplify_join_on)?;
190                let mut outers = outers.to_vec();
191                outers.insert(0, input.typ(&outers, &NO_PARAMS));
192                for scalar in scalars {
193                    walk_scalar(scalar, &outers, false, simplify_join_on)?;
194                    let (inner, outers) = outers
195                        .split_first_mut()
196                        .expect("outers known to have at least one element");
197                    let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
198                    inner.column_types.push(scalar_type);
199                }
200            }
201            HirRelationExpr::Filter { predicates, input } => {
202                walk_relation(input, outers, simplify_join_on)?;
203                let mut outers = outers.to_vec();
204                outers.insert(0, input.typ(&outers, &NO_PARAMS));
205                for pred in predicates {
206                    walk_scalar(pred, &outers, true, simplify_join_on)?;
207                }
208            }
209            HirRelationExpr::CallTable { exprs, .. } => {
210                let mut outers = outers.to_vec();
211                outers.insert(0, SqlRelationType::empty());
212                for scalar in exprs {
213                    walk_scalar(scalar, &outers, false, simplify_join_on)?;
214                }
215            }
216            HirRelationExpr::Join {
217                left, right, on, ..
218            } => {
219                walk_relation(left, outers, simplify_join_on)?;
220                let left_type = left.typ(outers, &NO_PARAMS);
221                let mut outers = outers.to_vec();
222                outers.insert(0, left_type);
223                walk_relation(right, &outers, simplify_join_on)?;
224                if simplify_join_on {
225                    // Build outers with the full join output type, since the
226                    // ON clause can reference columns from both sides.
227                    let right_type = right.typ(&outers, &NO_PARAMS);
228                    let mut join_columns = outers[0].column_types.clone();
229                    join_columns.extend(right_type.column_types);
230                    outers[0] = SqlRelationType::new(join_columns);
231                    walk_scalar(on, &outers, true, simplify_join_on)?;
232                }
233            }
234            expr => {
235                #[allow(deprecated)]
236                let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
237                    walk_relation(expr, outers, simplify_join_on)
238                });
239            }
240        }
241        Ok(())
242    }
243
244    fn walk_scalar(
245        expr: &mut HirScalarExpr,
246        outers: &[SqlRelationType],
247        mut in_filter: bool,
248        simplify_join_on: bool,
249    ) -> Result<(), RecursionLimitError> {
250        expr.try_visit_mut_pre(&mut |e| {
251            match e {
252                HirScalarExpr::Exists(input, _name) => {
253                    walk_relation(input, outers, simplify_join_on)?
254                }
255                HirScalarExpr::Select(input, _name) => {
256                    walk_relation(input, outers, simplify_join_on)?;
257
258                    // We're inside a `(SELECT ...)` subquery. Now let's see if
259                    // it has the form `(SELECT <any|all>(...) FROM <input>)`.
260                    // Ideally we could do this with one pattern, but Rust's pattern
261                    // matching engine is not powerful enough, so we have to do this
262                    // in stages; the early returns avoid brutal nesting.
263
264                    let (func, expr, input) = match &mut **input {
265                        HirRelationExpr::Reduce {
266                            group_key,
267                            aggregates,
268                            input,
269                            expected_group_size: _,
270                        } if group_key.is_empty() && aggregates.len() == 1 => {
271                            let agg = &mut aggregates[0];
272                            (&agg.func, &mut agg.expr, input)
273                        }
274                        _ => return Ok(()),
275                    };
276
277                    if !in_filter && column_type(outers, input, expr).nullable {
278                        // Unless we're directly inside a WHERE, this
279                        // transformation is only valid if the expression involved
280                        // is non-nullable.
281                        return Ok(());
282                    }
283
284                    match func {
285                        AggregateFunc::Any => {
286                            // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
287                            // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
288                            *e = input.take().filter(vec![expr.take()]).exists();
289                        }
290                        AggregateFunc::All => {
291                            // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
292                            // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
293                            //
294                            // Note that negation of <expr> alone is insufficient.
295                            // Consider that `WHERE <pred>` filters out rows if
296                            // `<pred>` is false *or* null. To invert the test, we
297                            // need `NOT <pred> OR <pred> IS NULL`.
298                            let expr = expr.take();
299                            let filter = expr.clone().not().or(expr.call_is_null());
300                            *e = input.take().filter(vec![filter]).exists().not();
301                        }
302                        _ => (),
303                    }
304                }
305                _ => {
306                    // As soon as we see *any* scalar expression, we are no longer
307                    // directly inside a filter.
308                    in_filter = false;
309                }
310            }
311            Ok(())
312        })
313    }
314
315    walk_relation(expr, &[], simplify_join_on)
316}
317
318/// An empty parameter type map.
319///
320/// These transformations are expected to run after parameters are bound, so
321/// there is no need to provide any parameter type information.
322static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
323
324fn column_type(
325    outers: &[SqlRelationType],
326    inner: &HirRelationExpr,
327    expr: &HirScalarExpr,
328) -> SqlColumnType {
329    let inner_type = inner.typ(outers, &NO_PARAMS);
330    expr.typ(outers, &inner_type, &NO_PARAMS)
331}
332
333impl HirScalarExpr {
334    /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
335    /// considers column references that target the root level.
336    /// (See `visit_columns_referring_to_root_level`.)
337    fn support(&self) -> Vec<usize> {
338        let mut result = Vec::new();
339        self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
340        result
341    }
342
343    /// Changes column references in `self` by the given remapping.
344    /// Panics if a referred column is not present in `idx_map`!
345    fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
346        self.visit_columns_referring_to_root_level_mut(&mut |c| {
347            *c = idx_map[c];
348        });
349        self
350    }
351}
352
353/// # Aims and scope
354///
355/// The aim here is to amortize the overhead of the MIR window function pattern
356/// (see `window_func_applied_to`) by fusing groups of window function calls such
357/// that each group can be performed by one instance of the window function MIR
358/// pattern.
359///
360/// For now, we fuse only value window function calls and window aggregations.
361/// (We probably won't need to fuse scalar window functions for a long time.)
362///
363/// For now, we can fuse value window function calls and window aggregations where the
364/// A. partition by
365/// B. order by
366/// C. window frame
367/// D. ignore nulls for value window functions and distinct for window aggregations
368/// are all the same. (See `extract_options`.)
369/// (Later, we could improve this to only need A. to be the same. This would require
370/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
371/// TODO: As a much simpler intermediate step, at least we should ignore options that
372/// don't matter. For example, we should be able to fuse a `lag` that has a default
373/// frame with a `first_value` that has some custom frame, because `lag` is not
374/// affected by the frame.)
375/// Note that we fuse value window function calls and window aggregations separately.
376///
377/// # Implementation
378///
/// At a high level, what we are going to do is look for Maps with more than one window function
/// call, and for each Map
381/// - remove some groups of window function call expressions from the Map's `scalars`;
382/// - insert a fused version of each group;
383/// - insert some expressions that decompose the results of the fused calls;
384/// - update some column references in `scalars`: those that refer to window function results that
385///   participated in fusion, as well as those that refer to columns that moved around due to
386///   removing and inserting expressions.
387/// - insert a Project above the matched Map to permute columns back to their original places.
388///
389/// It would be tempting to find groups simply by taking a list of all window function calls
390/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
391/// but a complication is that the possible groups that we could theoretically fuse overlap.
392/// This is because when forming groups we need to also take into account column references
393/// that point inside the same Map. For example, imagine a Map with the following scalar
394/// expressions:
395/// C1, E1, C2, C3, where
396/// - E1 refers to C1
397/// - C3 refers to E1.
398/// In this situation, we could either
399/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
400///   to it);
401/// - or fuse C2 and C3.
402/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
403/// no appropriate place for the fused expression: it would have to be both before and after E1.
404///
/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
/// we go through `scalars`, try to put each expression into each of our groups in order, and add it
/// to the first group for which this succeeds (starting a new group if none does). When trying to
/// put an expression into a group, we need to be mindful about column
408/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
409/// sanity is that the fused version of each group will be inserted at the place where the first
410/// element of the group originally was. This means that the only condition that we need to check on
411/// column references when adding an expression to a group is that all column references in a group
412/// should be to columns that are earlier than the first element of the group. (No need to check
413/// column references in the other direction, i.e., references in other expressions that refer to
414/// columns in the group.)
415pub fn fuse_window_functions(
416    root: &mut HirRelationExpr,
417    _context: &crate::plan::lowering::Context,
418) -> Result<(), RecursionLimitError> {
419    /// Those options of a window function call that are relevant for fusion.
420    #[derive(PartialEq, Eq)]
421    enum WindowFuncCallOptions {
422        Value(ValueWindowFuncCallOptions),
423        Agg(AggregateWindowFuncCallOptions),
424    }
425    #[derive(PartialEq, Eq)]
426    struct ValueWindowFuncCallOptions {
427        partition_by: Vec<HirScalarExpr>,
428        outer_order_by: Vec<HirScalarExpr>,
429        inner_order_by: Vec<ColumnOrder>,
430        window_frame: WindowFrame,
431        ignore_nulls: bool,
432    }
433    #[derive(PartialEq, Eq)]
434    struct AggregateWindowFuncCallOptions {
435        partition_by: Vec<HirScalarExpr>,
436        outer_order_by: Vec<HirScalarExpr>,
437        inner_order_by: Vec<ColumnOrder>,
438        window_frame: WindowFrame,
439        distinct: bool,
440    }
441
442    /// Helper function to extract the above options.
443    fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
444        match call {
445            HirScalarExpr::Windowing(
446                WindowExpr {
447                    func:
448                        WindowExprType::Value(ValueWindowExpr {
449                            order_by: inner_order_by,
450                            window_frame,
451                            ignore_nulls,
452                            func: _,
453                            args: _,
454                        }),
455                    partition_by,
456                    order_by: outer_order_by,
457                },
458                _name,
459            ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
460                partition_by: partition_by.clone(),
461                outer_order_by: outer_order_by.clone(),
462                inner_order_by: inner_order_by.clone(),
463                window_frame: window_frame.clone(),
464                ignore_nulls: ignore_nulls.clone(),
465            }),
466            HirScalarExpr::Windowing(
467                WindowExpr {
468                    func:
469                        WindowExprType::Aggregate(AggregateWindowExpr {
470                            aggregate_expr:
471                                AggregateExpr {
472                                    distinct,
473                                    func: _,
474                                    expr: _,
475                                },
476                            order_by: inner_order_by,
477                            window_frame,
478                        }),
479                    partition_by,
480                    order_by: outer_order_by,
481                },
482                _name,
483            ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
484                partition_by: partition_by.clone(),
485                outer_order_by: outer_order_by.clone(),
486                inner_order_by: inner_order_by.clone(),
487                window_frame: window_frame.clone(),
488                distinct: distinct.clone(),
489            }),
490            _ => panic!(
491                "extract_options should only be called on value window functions or window aggregations"
492            ),
493        }
494    }
495
496    struct FusionGroup {
497        /// The original column index of the first element of the group. (This is an index into the
498        /// Map's `scalars` plus the arity of the Map's input.)
499        first_col: usize,
500        /// The options of all the window function calls in the group. (Must be the same for all the
501        /// calls.)
502        options: WindowFuncCallOptions,
503        /// The calls in the group, with their original column indexes.
504        calls: Vec<(usize, HirScalarExpr)>,
505    }
506
507    impl FusionGroup {
508        /// Creates a window function call that is a fused version of all the calls in the group.
509        /// `new_col` is the column index where the fused call will be inserted at.
510        fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
511            let fused = match self.options {
512                WindowFuncCallOptions::Value(options) => {
513                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
514                        .calls
515                        .iter()
516                        .map(|(_idx, call)| {
517                            if let HirScalarExpr::Windowing(
518                                WindowExpr {
519                                    func:
520                                        WindowExprType::Value(ValueWindowExpr {
521                                            func,
522                                            args,
523                                            order_by: _,
524                                            window_frame: _,
525                                            ignore_nulls: _,
526                                        }),
527                                    partition_by: _,
528                                    order_by: _,
529                                },
530                                _name,
531                            ) = call
532                            {
533                                (func.clone(), (**args).clone())
534                            } else {
535                                panic!("unknown window function in FusionGroup")
536                            }
537                        })
538                        .unzip();
539                    let fused_args = HirScalarExpr::call_variadic(
540                        RecordCreate {
541                            // These field names are not important, because this record will only be an
542                            // intermediate expression, which we'll manipulate further before it ends up
543                            // anywhere where a column name would be visible.
544                            field_names: iter::repeat(ColumnName::from(""))
545                                .take(fused_args.len())
546                                .collect(),
547                        },
548                        fused_args,
549                    );
550                    HirScalarExpr::windowing(WindowExpr {
551                        func: WindowExprType::Value(ValueWindowExpr {
552                            func: ValueWindowFunc::Fused(fused_funcs),
553                            args: Box::new(fused_args),
554                            order_by: options.inner_order_by,
555                            window_frame: options.window_frame,
556                            ignore_nulls: options.ignore_nulls,
557                        }),
558                        partition_by: options.partition_by,
559                        order_by: options.outer_order_by,
560                    })
561                }
562                WindowFuncCallOptions::Agg(options) => {
563                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
564                        .calls
565                        .iter()
566                        .map(|(_idx, call)| {
567                            if let HirScalarExpr::Windowing(
568                                WindowExpr {
569                                    func:
570                                        WindowExprType::Aggregate(AggregateWindowExpr {
571                                            aggregate_expr:
572                                                AggregateExpr {
573                                                    func,
574                                                    expr,
575                                                    distinct: _,
576                                                },
577                                            order_by: _,
578                                            window_frame: _,
579                                        }),
580                                    partition_by: _,
581                                    order_by: _,
582                                },
583                                _name,
584                            ) = call
585                            {
586                                (func.clone(), (**expr).clone())
587                            } else {
588                                panic!("unknown window function in FusionGroup")
589                            }
590                        })
591                        .unzip();
592                    let fused_args = HirScalarExpr::call_variadic(
593                        RecordCreate {
594                            field_names: iter::repeat(ColumnName::from(""))
595                                .take(fused_args.len())
596                                .collect(),
597                        },
598                        fused_args,
599                    );
600                    HirScalarExpr::windowing(WindowExpr {
601                        func: WindowExprType::Aggregate(AggregateWindowExpr {
602                            aggregate_expr: AggregateExpr {
603                                func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
604                                expr: Box::new(fused_args),
605                                distinct: options.distinct,
606                            },
607                            order_by: options.inner_order_by,
608                            window_frame: options.window_frame,
609                        }),
610                        partition_by: options.partition_by,
611                        order_by: options.outer_order_by,
612                    })
613                }
614            };
615
616            let decompositions = (0..self.calls.len())
617                .map(|field| {
618                    HirScalarExpr::column(new_col)
619                        .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
620                })
621                .collect();
622
623            (fused, decompositions)
624        }
625    }
626
627    let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
628        // Look for calls only at the root of scalar expressions. This is enough
629        // because they are always there, see 72e84bb78.
630        match scalar_expr {
631            HirScalarExpr::Windowing(
632                WindowExpr {
633                    func: WindowExprType::Value(ValueWindowExpr { func, .. }),
634                    ..
635                },
636                _name,
637            ) => {
638                // Exclude those calls that are already fused. (We shouldn't currently
639                // encounter these, because we just do one pass, but it's better to be
640                // robust against future code changes.)
641                !matches!(func, ValueWindowFunc::Fused(..))
642            }
643            HirScalarExpr::Windowing(
644                WindowExpr {
645                    func:
646                        WindowExprType::Aggregate(AggregateWindowExpr {
647                            aggregate_expr: AggregateExpr { func, .. },
648                            ..
649                        }),
650                    ..
651                },
652                _name,
653            ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
654            _ => false,
655        }
656    };
657
658    root.try_visit_mut_post(&mut |rel_expr| {
659        match rel_expr {
660            HirRelationExpr::Map { input, scalars } => {
661                // There will be various variable names involving `idx` or `col`:
662                // - `idx` will always be an index into `scalars` or something similar,
663                // - `col` will always be a column index,
664                //   which is often `arity_before_map` + an index into `scalars`.
665                let arity_before_map = input.arity();
666                let orig_num_scalars = scalars.len();
667
668                // Collect all value window function calls and window aggregations with their column
669                // indexes.
670                let value_or_agg_window_func_calls = scalars
671                    .iter()
672                    .enumerate()
673                    .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
674                    .map(|(idx, call)| (idx + arity_before_map, call.clone()))
675                    .collect_vec();
676                // Exit early if obviously no chance for fusion.
677                if value_or_agg_window_func_calls.len() <= 1 {
678                    // Note that we are doing this only for performance. All plans should be exactly
679                    // the same even if we comment out the following line.
680                    return Ok(());
681                }
682
683                // Determine the fusion groups. (Each group will later be fused into one window
684                // function call.)
685                // Note that this has a quadratic run time with value_or_agg_window_func_calls in
686                // the worst case. However, this is fine even with 1000 window function calls.
687                let mut groups: Vec<FusionGroup> = Vec::new();
688                for (col, call) in value_or_agg_window_func_calls {
689                    let options = extract_options(&call);
690                    let support = call.support();
691                    let to_fuse_with = groups
692                        .iter_mut()
693                        .filter(|group| {
694                            group.options == options && support.iter().all(|c| *c < group.first_col)
695                        })
696                        .next();
697                    if let Some(group) = to_fuse_with {
698                        group.calls.push((col, call.clone()));
699                    } else {
700                        groups.push(FusionGroup {
701                            first_col: col,
702                            options,
703                            calls: vec![(col, call.clone())],
704                        });
705                    }
706                }
707
708                // No fusion to do on groups of 1.
709                groups.retain(|g| g.calls.len() > 1);
710
711                let removals: BTreeSet<usize> = groups
712                    .iter()
713                    .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
714                    .collect();
715
716                // Mutate `scalars`.
717                // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
718                // `groups` is already sorted by `first_col` due to the way it was constructed.)
719                // We also compute a remapping of old indexes to new indexes as we go.
720                let mut groups_it = groups.drain(..).peekable();
721                let mut group = groups_it.next();
722                let mut remap = BTreeMap::new();
723                remap.extend((0..arity_before_map).map(|col| (col, col)));
724                let mut new_col: usize = arity_before_map;
725                let mut new_scalars = Vec::new();
726                for (old_col, e) in scalars
727                    .drain(..)
728                    .enumerate()
729                    .map(|(idx, e)| (idx + arity_before_map, e))
730                {
731                    if group.as_ref().is_some_and(|g| g.first_col == old_col) {
732                        // The current expression will be fused away, and a fused expression will
733                        // appear in its place. Additionally, some new expressions will be inserted
734                        // after the fused expression, to decompose the record that is the result of
735                        // the fused call.
736                        assert!(removals.contains(&old_col));
737                        let group_unwrapped = group.expect("checked above");
738                        let calls_cols = group_unwrapped
739                            .calls
740                            .iter()
741                            .map(|(col, _call)| *col)
742                            .collect_vec();
743                        let (fused, decompositions) = group_unwrapped.fuse(new_col);
744                        new_scalars.push(fused.remap(&remap));
745                        new_scalars.extend(decompositions); // (no remapping needed)
746                        new_col += 1;
747                        for call_old_col in calls_cols {
748                            let present = remap.insert(call_old_col, new_col);
749                            assert!(present.is_none());
750                            new_col += 1;
751                        }
752                        group = groups_it.next();
753                    } else if removals.contains(&old_col) {
754                        assert!(remap.contains_key(&old_col));
755                    } else {
756                        new_scalars.push(e.remap(&remap));
757                        let present = remap.insert(old_col, new_col);
758                        assert!(present.is_none());
759                        new_col += 1;
760                    }
761                }
762                *scalars = new_scalars;
763                assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
764
765                // Add a project to permute columns back to their original places.
766                *rel_expr = rel_expr.take().project(
767                    (0..arity_before_map)
768                        .chain((0..orig_num_scalars).map(|idx| {
769                            *remap
770                                .get(&(idx + arity_before_map))
771                                .expect("all columns should be present by now")
772                        }))
773                        .collect(),
774                );
775
776                assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
777            }
778            _ => {}
779        }
780        Ok(())
781    })
782}