Skip to main content

mz_sql/plan/
transform_hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::func::variadic::RecordCreate;
19use mz_expr::visit::Visit;
20use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
21use mz_ore::stack::RecursionLimitError;
22use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
23
24use crate::plan::hir::{
25    AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
26    ValueWindowExpr, ValueWindowFunc, WindowExpr,
27};
28use crate::plan::{AggregateExpr, WindowExprType};
29
30/// Rewrites predicates that contain subqueries so that the subqueries
31/// appear in their own later predicate when possible.
32///
33/// For example, this function rewrites this expression
34///
35/// ```text
36/// Filter {
37///     predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
38/// }
39/// ```
40///
41/// like so:
42///
43/// ```text
44/// Filter {
45///     predicates: [
46///         a = b AND c = d,
47///         EXISTS (<subquery>),
48///         (<subquery 2>) = e,
49///     ]
50/// }
51/// ```
52///
53/// The rewrite causes decorrelation to incorporate prior predicates into
54/// the outer relation upon which the subquery is evaluated. In the above
55/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
56/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
57/// = e`, will be further restricted to outer rows that match `A = b AND c =
58/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
59/// subquery, especially when the original conjunction contains join keys.
60pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61    fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
62        #[allow(deprecated)]
63        expr.visit_mut_fallible(0, &mut |expr, _| {
64            match expr {
65                HirRelationExpr::Map { scalars, .. } => {
66                    for scalar in scalars {
67                        walk_scalar(scalar)?;
68                    }
69                }
70                HirRelationExpr::CallTable { exprs, .. } => {
71                    for expr in exprs {
72                        walk_scalar(expr)?;
73                    }
74                }
75                HirRelationExpr::Filter { predicates, .. } => {
76                    let mut subqueries = vec![];
77                    for predicate in &mut *predicates {
78                        walk_scalar(predicate)?;
79                        extract_conjuncted_subqueries(predicate, &mut subqueries)?;
80                    }
81                    // TODO(benesch): we could be smarter about the order in which
82                    // we emit subqueries. At the moment we just emit in the order
83                    // we discovered them, but ideally we'd emit them in an order
84                    // that accounted for their cost/selectivity. E.g., low-cost,
85                    // high-selectivity subqueries should go first.
86                    for subquery in subqueries {
87                        predicates.push(subquery);
88                    }
89                }
90                _ => (),
91            }
92            Ok(())
93        })
94    }
95
96    fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
97        expr.try_visit_direct_subqueries_mut(&mut walk_relation)
98    }
99
100    fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
101        let mut found = false;
102        expr.try_visit_direct_subqueries(|_| {
103            found = true;
104            Ok(())
105        })?;
106        Ok(found)
107    }
108
109    /// Extracts subqueries from a conjunction into `out`.
110    ///
111    /// For example, given an expression like
112    ///
113    /// ```text
114    /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
115    /// ```
116    ///
117    /// this function rewrites the expression to
118    ///
119    /// ```text
120    /// a = b AND true AND c = d AND true
121    /// ```
122    ///
123    /// and returns the expression fragments `EXISTS (<subquery 1>)` and
124    /// `(<subquery 2>) = e` in the `out` vector.
125    fn extract_conjuncted_subqueries(
126        expr: &mut HirScalarExpr,
127        out: &mut Vec<HirScalarExpr>,
128    ) -> Result<(), RecursionLimitError> {
129        match expr {
130            HirScalarExpr::CallVariadic {
131                func: VariadicFunc::And(_),
132                exprs,
133                name: _,
134            } => {
135                exprs
136                    .into_iter()
137                    .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
138            }
139            expr if contains_subquery(expr)? => {
140                out.push(mem::replace(expr, HirScalarExpr::literal_true()))
141            }
142            _ => (),
143        }
144        Ok(())
145    }
146
147    walk_relation(expr)
148}
149
150/// Rewrites quantified comparisons into simpler EXISTS operators.
151///
152/// Note that this transformation is only valid when the expression is
153/// used in a context where the distinction between `FALSE` and `NULL`
154/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
155/// when the inputs to the comparison are non-nullable. This function is careful
156/// to only apply the transformation when it is valid to do so.
157///
158/// ```ignore
159/// WHERE (SELECT any(<pred>) FROM <rel>)
160/// =>
161/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
162///
163/// WHERE (SELECT all(<pred>) FROM <rel>)
164/// =>
165/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
166/// ```
167///
168/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
169/// M. Elhemali, et al.
170pub fn try_simplify_quantified_comparisons(
171    expr: &mut HirRelationExpr,
172    simplify_join_on: bool,
173) -> Result<(), RecursionLimitError> {
174    fn walk_relation(
175        expr: &mut HirRelationExpr,
176        outers: &[SqlRelationType],
177        simplify_join_on: bool,
178    ) -> Result<(), RecursionLimitError> {
179        match expr {
180            HirRelationExpr::Map { scalars, input } => {
181                walk_relation(input, outers, simplify_join_on)?;
182                let mut outers = outers.to_vec();
183                outers.insert(0, input.typ(&outers, &NO_PARAMS));
184                for scalar in scalars {
185                    walk_scalar(scalar, &outers, false, simplify_join_on)?;
186                    let (inner, outers) = outers
187                        .split_first_mut()
188                        .expect("outers known to have at least one element");
189                    let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
190                    inner.column_types.push(scalar_type);
191                }
192            }
193            HirRelationExpr::Filter { predicates, input } => {
194                walk_relation(input, outers, simplify_join_on)?;
195                let mut outers = outers.to_vec();
196                outers.insert(0, input.typ(&outers, &NO_PARAMS));
197                for pred in predicates {
198                    walk_scalar(pred, &outers, true, simplify_join_on)?;
199                }
200            }
201            HirRelationExpr::CallTable { exprs, .. } => {
202                let mut outers = outers.to_vec();
203                outers.insert(0, SqlRelationType::empty());
204                for scalar in exprs {
205                    walk_scalar(scalar, &outers, false, simplify_join_on)?;
206                }
207            }
208            HirRelationExpr::Join {
209                left, right, on, ..
210            } => {
211                walk_relation(left, outers, simplify_join_on)?;
212                let left_type = left.typ(outers, &NO_PARAMS);
213                let mut outers = outers.to_vec();
214                outers.insert(0, left_type);
215                walk_relation(right, &outers, simplify_join_on)?;
216                if simplify_join_on {
217                    // Build outers with the full join output type, since the
218                    // ON clause can reference columns from both sides.
219                    let right_type = right.typ(&outers, &NO_PARAMS);
220                    let mut join_columns = outers[0].column_types.clone();
221                    join_columns.extend(right_type.column_types);
222                    outers[0] = SqlRelationType::new(join_columns);
223                    walk_scalar(on, &outers, true, simplify_join_on)?;
224                }
225            }
226            expr => {
227                #[allow(deprecated)]
228                let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
229                    walk_relation(expr, outers, simplify_join_on)
230                });
231            }
232        }
233        Ok(())
234    }
235
236    fn walk_scalar(
237        expr: &mut HirScalarExpr,
238        outers: &[SqlRelationType],
239        mut in_filter: bool,
240        simplify_join_on: bool,
241    ) -> Result<(), RecursionLimitError> {
242        expr.try_visit_mut_pre(&mut |e| {
243            match e {
244                HirScalarExpr::Exists(input, _name) => {
245                    walk_relation(input, outers, simplify_join_on)?
246                }
247                HirScalarExpr::Select(input, _name) => {
248                    walk_relation(input, outers, simplify_join_on)?;
249
250                    // We're inside a `(SELECT ...)` subquery. Now let's see if
251                    // it has the form `(SELECT <any|all>(...) FROM <input>)`.
252                    // Ideally we could do this with one pattern, but Rust's pattern
253                    // matching engine is not powerful enough, so we have to do this
254                    // in stages; the early returns avoid brutal nesting.
255
256                    let (func, expr, input) = match &mut **input {
257                        HirRelationExpr::Reduce {
258                            group_key,
259                            aggregates,
260                            input,
261                            expected_group_size: _,
262                        } if group_key.is_empty() && aggregates.len() == 1 => {
263                            let agg = &mut aggregates[0];
264                            (&agg.func, &mut agg.expr, input)
265                        }
266                        _ => return Ok(()),
267                    };
268
269                    if !in_filter && column_type(outers, input, expr).nullable {
270                        // Unless we're directly inside a WHERE, this
271                        // transformation is only valid if the expression involved
272                        // is non-nullable.
273                        return Ok(());
274                    }
275
276                    match func {
277                        AggregateFunc::Any => {
278                            // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
279                            // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
280                            *e = input.take().filter(vec![expr.take()]).exists();
281                        }
282                        AggregateFunc::All => {
283                            // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
284                            // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
285                            //
286                            // Note that negation of <expr> alone is insufficient.
287                            // Consider that `WHERE <pred>` filters out rows if
288                            // `<pred>` is false *or* null. To invert the test, we
289                            // need `NOT <pred> OR <pred> IS NULL`.
290                            let expr = expr.take();
291                            let filter = expr.clone().not().or(expr.call_is_null());
292                            *e = input.take().filter(vec![filter]).exists().not();
293                        }
294                        _ => (),
295                    }
296                }
297                _ => {
298                    // As soon as we see *any* scalar expression, we are no longer
299                    // directly inside a filter.
300                    in_filter = false;
301                }
302            }
303            Ok(())
304        })
305    }
306
307    walk_relation(expr, &[], simplify_join_on)
308}
309
310/// An empty parameter type map.
311///
312/// These transformations are expected to run after parameters are bound, so
313/// there is no need to provide any parameter type information.
314static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
315
316fn column_type(
317    outers: &[SqlRelationType],
318    inner: &HirRelationExpr,
319    expr: &HirScalarExpr,
320) -> SqlColumnType {
321    let inner_type = inner.typ(outers, &NO_PARAMS);
322    expr.typ(outers, &inner_type, &NO_PARAMS)
323}
324
325impl HirScalarExpr {
326    /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
327    /// considers column references that target the root level.
328    /// (See `visit_columns_referring_to_root_level`.)
329    fn support(&self) -> Vec<usize> {
330        let mut result = Vec::new();
331        self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
332        result
333    }
334
335    /// Changes column references in `self` by the given remapping.
336    /// Panics if a referred column is not present in `idx_map`!
337    fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
338        self.visit_columns_referring_to_root_level_mut(&mut |c| {
339            *c = idx_map[c];
340        });
341        self
342    }
343}
344
345/// # Aims and scope
346///
347/// The aim here is to amortize the overhead of the MIR window function pattern
348/// (see `window_func_applied_to`) by fusing groups of window function calls such
349/// that each group can be performed by one instance of the window function MIR
350/// pattern.
351///
352/// For now, we fuse only value window function calls and window aggregations.
353/// (We probably won't need to fuse scalar window functions for a long time.)
354///
355/// For now, we can fuse value window function calls and window aggregations where the
356/// A. partition by
357/// B. order by
358/// C. window frame
359/// D. ignore nulls for value window functions and distinct for window aggregations
360/// are all the same. (See `extract_options`.)
361/// (Later, we could improve this to only need A. to be the same. This would require
362/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
363/// TODO: As a much simpler intermediate step, at least we should ignore options that
364/// don't matter. For example, we should be able to fuse a `lag` that has a default
365/// frame with a `first_value` that has some custom frame, because `lag` is not
366/// affected by the frame.)
367/// Note that we fuse value window function calls and window aggregations separately.
368///
369/// # Implementation
370///
371/// At a high level, what we are going to do is look for Maps with more than one window function
372/// call, and for each such Map
373/// - remove some groups of window function call expressions from the Map's `scalars`;
374/// - insert a fused version of each group;
375/// - insert some expressions that decompose the results of the fused calls;
376/// - update some column references in `scalars`: those that refer to window function results that
377///   participated in fusion, as well as those that refer to columns that moved around due to
378///   removing and inserting expressions.
379/// - insert a Project above the matched Map to permute columns back to their original places.
380///
381/// It would be tempting to find groups simply by taking a list of all window function calls
382/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
383/// but a complication is that the possible groups that we could theoretically fuse overlap.
384/// This is because when forming groups we need to also take into account column references
385/// that point inside the same Map. For example, imagine a Map with the following scalar
386/// expressions:
387/// C1, E1, C2, C3, where
388/// - E1 refers to C1
389/// - C3 refers to E1.
390/// In this situation, we could either
391/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
392///   to it);
393/// - or fuse C2 and C3.
394/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
395/// no appropriate place for the fused expression: it would have to be both before and after E1.
396///
397/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
398/// we go through `scalars` and try to put each call into each of our groups in turn; the first group
399/// that accepts it wins. When trying to put an expression into a group, we need to be mindful about column
400/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
401/// sanity is that the fused version of each group will be inserted at the place where the first
402/// element of the group originally was. This means that the only condition that we need to check on
403/// column references when adding an expression to a group is that all column references in a group
404/// should be to columns that are earlier than the first element of the group. (No need to check
405/// column references in the other direction, i.e., references in other expressions that refer to
406/// columns in the group.)
407pub fn fuse_window_functions(
408    root: &mut HirRelationExpr,
409    _context: &crate::plan::lowering::Context,
410) -> Result<(), RecursionLimitError> {
411    /// Those options of a window function call that are relevant for fusion.
412    #[derive(PartialEq, Eq)]
413    enum WindowFuncCallOptions {
414        Value(ValueWindowFuncCallOptions),
415        Agg(AggregateWindowFuncCallOptions),
416    }
417    #[derive(PartialEq, Eq)]
418    struct ValueWindowFuncCallOptions {
419        partition_by: Vec<HirScalarExpr>,
420        outer_order_by: Vec<HirScalarExpr>,
421        inner_order_by: Vec<ColumnOrder>,
422        window_frame: WindowFrame,
423        ignore_nulls: bool,
424    }
425    #[derive(PartialEq, Eq)]
426    struct AggregateWindowFuncCallOptions {
427        partition_by: Vec<HirScalarExpr>,
428        outer_order_by: Vec<HirScalarExpr>,
429        inner_order_by: Vec<ColumnOrder>,
430        window_frame: WindowFrame,
431        distinct: bool,
432    }
433
434    /// Helper function to extract the above options.
435    fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
436        match call {
437            HirScalarExpr::Windowing(
438                WindowExpr {
439                    func:
440                        WindowExprType::Value(ValueWindowExpr {
441                            order_by: inner_order_by,
442                            window_frame,
443                            ignore_nulls,
444                            func: _,
445                            args: _,
446                        }),
447                    partition_by,
448                    order_by: outer_order_by,
449                },
450                _name,
451            ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
452                partition_by: partition_by.clone(),
453                outer_order_by: outer_order_by.clone(),
454                inner_order_by: inner_order_by.clone(),
455                window_frame: window_frame.clone(),
456                ignore_nulls: ignore_nulls.clone(),
457            }),
458            HirScalarExpr::Windowing(
459                WindowExpr {
460                    func:
461                        WindowExprType::Aggregate(AggregateWindowExpr {
462                            aggregate_expr:
463                                AggregateExpr {
464                                    distinct,
465                                    func: _,
466                                    expr: _,
467                                },
468                            order_by: inner_order_by,
469                            window_frame,
470                        }),
471                    partition_by,
472                    order_by: outer_order_by,
473                },
474                _name,
475            ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
476                partition_by: partition_by.clone(),
477                outer_order_by: outer_order_by.clone(),
478                inner_order_by: inner_order_by.clone(),
479                window_frame: window_frame.clone(),
480                distinct: distinct.clone(),
481            }),
482            _ => panic!(
483                "extract_options should only be called on value window functions or window aggregations"
484            ),
485        }
486    }
487
488    struct FusionGroup {
489        /// The original column index of the first element of the group. (This is an index into the
490        /// Map's `scalars` plus the arity of the Map's input.)
491        first_col: usize,
492        /// The options of all the window function calls in the group. (Must be the same for all the
493        /// calls.)
494        options: WindowFuncCallOptions,
495        /// The calls in the group, with their original column indexes.
496        calls: Vec<(usize, HirScalarExpr)>,
497    }
498
499    impl FusionGroup {
500        /// Creates a window function call that is a fused version of all the calls in the group.
501        /// `new_col` is the column index where the fused call will be inserted at.
502        fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
503            let fused = match self.options {
504                WindowFuncCallOptions::Value(options) => {
505                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
506                        .calls
507                        .iter()
508                        .map(|(_idx, call)| {
509                            if let HirScalarExpr::Windowing(
510                                WindowExpr {
511                                    func:
512                                        WindowExprType::Value(ValueWindowExpr {
513                                            func,
514                                            args,
515                                            order_by: _,
516                                            window_frame: _,
517                                            ignore_nulls: _,
518                                        }),
519                                    partition_by: _,
520                                    order_by: _,
521                                },
522                                _name,
523                            ) = call
524                            {
525                                (func.clone(), (**args).clone())
526                            } else {
527                                panic!("unknown window function in FusionGroup")
528                            }
529                        })
530                        .unzip();
531                    let fused_args = HirScalarExpr::call_variadic(
532                        RecordCreate {
533                            // These field names are not important, because this record will only be an
534                            // intermediate expression, which we'll manipulate further before it ends up
535                            // anywhere where a column name would be visible.
536                            field_names: iter::repeat(ColumnName::from(""))
537                                .take(fused_args.len())
538                                .collect(),
539                        },
540                        fused_args,
541                    );
542                    HirScalarExpr::windowing(WindowExpr {
543                        func: WindowExprType::Value(ValueWindowExpr {
544                            func: ValueWindowFunc::Fused(fused_funcs),
545                            args: Box::new(fused_args),
546                            order_by: options.inner_order_by,
547                            window_frame: options.window_frame,
548                            ignore_nulls: options.ignore_nulls,
549                        }),
550                        partition_by: options.partition_by,
551                        order_by: options.outer_order_by,
552                    })
553                }
554                WindowFuncCallOptions::Agg(options) => {
555                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
556                        .calls
557                        .iter()
558                        .map(|(_idx, call)| {
559                            if let HirScalarExpr::Windowing(
560                                WindowExpr {
561                                    func:
562                                        WindowExprType::Aggregate(AggregateWindowExpr {
563                                            aggregate_expr:
564                                                AggregateExpr {
565                                                    func,
566                                                    expr,
567                                                    distinct: _,
568                                                },
569                                            order_by: _,
570                                            window_frame: _,
571                                        }),
572                                    partition_by: _,
573                                    order_by: _,
574                                },
575                                _name,
576                            ) = call
577                            {
578                                (func.clone(), (**expr).clone())
579                            } else {
580                                panic!("unknown window function in FusionGroup")
581                            }
582                        })
583                        .unzip();
584                    let fused_args = HirScalarExpr::call_variadic(
585                        RecordCreate {
586                            field_names: iter::repeat(ColumnName::from(""))
587                                .take(fused_args.len())
588                                .collect(),
589                        },
590                        fused_args,
591                    );
592                    HirScalarExpr::windowing(WindowExpr {
593                        func: WindowExprType::Aggregate(AggregateWindowExpr {
594                            aggregate_expr: AggregateExpr {
595                                func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
596                                expr: Box::new(fused_args),
597                                distinct: options.distinct,
598                            },
599                            order_by: options.inner_order_by,
600                            window_frame: options.window_frame,
601                        }),
602                        partition_by: options.partition_by,
603                        order_by: options.outer_order_by,
604                    })
605                }
606            };
607
608            let decompositions = (0..self.calls.len())
609                .map(|field| {
610                    HirScalarExpr::column(new_col)
611                        .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
612                })
613                .collect();
614
615            (fused, decompositions)
616        }
617    }
618
619    let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
620        // Look for calls only at the root of scalar expressions. This is enough
621        // because they are always there, see 72e84bb78.
622        match scalar_expr {
623            HirScalarExpr::Windowing(
624                WindowExpr {
625                    func: WindowExprType::Value(ValueWindowExpr { func, .. }),
626                    ..
627                },
628                _name,
629            ) => {
630                // Exclude those calls that are already fused. (We shouldn't currently
631                // encounter these, because we just do one pass, but it's better to be
632                // robust against future code changes.)
633                !matches!(func, ValueWindowFunc::Fused(..))
634            }
635            HirScalarExpr::Windowing(
636                WindowExpr {
637                    func:
638                        WindowExprType::Aggregate(AggregateWindowExpr {
639                            aggregate_expr: AggregateExpr { func, .. },
640                            ..
641                        }),
642                    ..
643                },
644                _name,
645            ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
646            _ => false,
647        }
648    };
649
650    root.try_visit_mut_post(&mut |rel_expr| {
651        match rel_expr {
652            HirRelationExpr::Map { input, scalars } => {
653                // There will be various variable names involving `idx` or `col`:
654                // - `idx` will always be an index into `scalars` or something similar,
655                // - `col` will always be a column index,
656                //   which is often `arity_before_map` + an index into `scalars`.
657                let arity_before_map = input.arity();
658                let orig_num_scalars = scalars.len();
659
660                // Collect all value window function calls and window aggregations with their column
661                // indexes.
662                let value_or_agg_window_func_calls = scalars
663                    .iter()
664                    .enumerate()
665                    .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
666                    .map(|(idx, call)| (idx + arity_before_map, call.clone()))
667                    .collect_vec();
668                // Exit early if obviously no chance for fusion.
669                if value_or_agg_window_func_calls.len() <= 1 {
670                    // Note that we are doing this only for performance. All plans should be exactly
671                    // the same even if we comment out the following line.
672                    return Ok(());
673                }
674
675                // Determine the fusion groups. (Each group will later be fused into one window
676                // function call.)
677                // Note that this has a quadratic run time with value_or_agg_window_func_calls in
678                // the worst case. However, this is fine even with 1000 window function calls.
679                let mut groups: Vec<FusionGroup> = Vec::new();
680                for (col, call) in value_or_agg_window_func_calls {
681                    let options = extract_options(&call);
682                    let support = call.support();
683                    let to_fuse_with = groups
684                        .iter_mut()
685                        .filter(|group| {
686                            group.options == options && support.iter().all(|c| *c < group.first_col)
687                        })
688                        .next();
689                    if let Some(group) = to_fuse_with {
690                        group.calls.push((col, call.clone()));
691                    } else {
692                        groups.push(FusionGroup {
693                            first_col: col,
694                            options,
695                            calls: vec![(col, call.clone())],
696                        });
697                    }
698                }
699
700                // No fusion to do on groups of 1.
701                groups.retain(|g| g.calls.len() > 1);
702
703                let removals: BTreeSet<usize> = groups
704                    .iter()
705                    .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
706                    .collect();
707
708                // Mutate `scalars`.
709                // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
710                // `groups` is already sorted by `first_col` due to the way it was constructed.)
711                // We also compute a remapping of old indexes to new indexes as we go.
712                let mut groups_it = groups.drain(..).peekable();
713                let mut group = groups_it.next();
714                let mut remap = BTreeMap::new();
715                remap.extend((0..arity_before_map).map(|col| (col, col)));
716                let mut new_col: usize = arity_before_map;
717                let mut new_scalars = Vec::new();
718                for (old_col, e) in scalars
719                    .drain(..)
720                    .enumerate()
721                    .map(|(idx, e)| (idx + arity_before_map, e))
722                {
723                    if group.as_ref().is_some_and(|g| g.first_col == old_col) {
724                        // The current expression will be fused away, and a fused expression will
725                        // appear in its place. Additionally, some new expressions will be inserted
726                        // after the fused expression, to decompose the record that is the result of
727                        // the fused call.
728                        assert!(removals.contains(&old_col));
729                        let group_unwrapped = group.expect("checked above");
730                        let calls_cols = group_unwrapped
731                            .calls
732                            .iter()
733                            .map(|(col, _call)| *col)
734                            .collect_vec();
735                        let (fused, decompositions) = group_unwrapped.fuse(new_col);
736                        new_scalars.push(fused.remap(&remap));
737                        new_scalars.extend(decompositions); // (no remapping needed)
738                        new_col += 1;
739                        for call_old_col in calls_cols {
740                            let present = remap.insert(call_old_col, new_col);
741                            assert!(present.is_none());
742                            new_col += 1;
743                        }
744                        group = groups_it.next();
745                    } else if removals.contains(&old_col) {
746                        assert!(remap.contains_key(&old_col));
747                    } else {
748                        new_scalars.push(e.remap(&remap));
749                        let present = remap.insert(old_col, new_col);
750                        assert!(present.is_none());
751                        new_col += 1;
752                    }
753                }
754                *scalars = new_scalars;
755                assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
756
757                // Add a project to permute columns back to their original places.
758                *rel_expr = rel_expr.take().project(
759                    (0..arity_before_map)
760                        .chain((0..orig_num_scalars).map(|idx| {
761                            *remap
762                                .get(&(idx + arity_before_map))
763                                .expect("all columns should be present by now")
764                        }))
765                        .collect(),
766                );
767
768                assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
769            }
770            _ => {}
771        }
772        Ok(())
773    })
774}