mz_sql/plan/
transform_hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::visit::Visit;
19use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
20use mz_ore::stack::RecursionLimitError;
21use mz_repr::{ColumnName, ColumnType, RelationType, ScalarType};
22
23use crate::plan::hir::{
24    AbstractExpr, AggregateFunc, AggregateWindowExpr, ColumnRef, HirRelationExpr, HirScalarExpr,
25    ValueWindowExpr, ValueWindowFunc, WindowExpr,
26};
27use crate::plan::{AggregateExpr, WindowExprType};
28
29/// Rewrites predicates that contain subqueries so that the subqueries
30/// appear in their own later predicate when possible.
31///
32/// For example, this function rewrites this expression
33///
34/// ```text
35/// Filter {
36///     predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
37/// }
38/// ```
39///
40/// like so:
41///
42/// ```text
43/// Filter {
44///     predicates: [
45///         a = b AND c = d,
46///         EXISTS (<subquery>),
47///         (<subquery 2>) = e,
48///     ]
49/// }
50/// ```
51///
52/// The rewrite causes decorrelation to incorporate prior predicates into
53/// the outer relation upon which the subquery is evaluated. In the above
54/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
55/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
56/// = e`, will be further restricted to outer rows that match `A = b AND c =
57/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
58/// subquery, especially when the original conjunction contains join keys.
59pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
60    fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61        #[allow(deprecated)]
62        expr.visit_mut_fallible(0, &mut |expr, _| {
63            match expr {
64                HirRelationExpr::Map { scalars, .. } => {
65                    for scalar in scalars {
66                        walk_scalar(scalar)?;
67                    }
68                }
69                HirRelationExpr::CallTable { exprs, .. } => {
70                    for expr in exprs {
71                        walk_scalar(expr)?;
72                    }
73                }
74                HirRelationExpr::Filter { predicates, .. } => {
75                    let mut subqueries = vec![];
76                    for predicate in &mut *predicates {
77                        walk_scalar(predicate)?;
78                        extract_conjuncted_subqueries(predicate, &mut subqueries)?;
79                    }
80                    // TODO(benesch): we could be smarter about the order in which
81                    // we emit subqueries. At the moment we just emit in the order
82                    // we discovered them, but ideally we'd emit them in an order
83                    // that accounted for their cost/selectivity. E.g., low-cost,
84                    // high-selectivity subqueries should go first.
85                    for subquery in subqueries {
86                        predicates.push(subquery);
87                    }
88                }
89                _ => (),
90            }
91            Ok(())
92        })
93    }
94
95    fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
96        expr.try_visit_mut_post(&mut |expr| {
97            match expr {
98                HirScalarExpr::Exists(input) | HirScalarExpr::Select(input) => {
99                    walk_relation(input)?
100                }
101                _ => (),
102            }
103            Ok(())
104        })
105    }
106
107    fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
108        let mut found = false;
109        expr.visit_pre(&mut |expr| match expr {
110            HirScalarExpr::Exists(_) | HirScalarExpr::Select(_) => found = true,
111            _ => (),
112        })?;
113        Ok(found)
114    }
115
116    /// Extracts subqueries from a conjunction into `out`.
117    ///
118    /// For example, given an expression like
119    ///
120    /// ```text
121    /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
122    /// ```
123    ///
124    /// this function rewrites the expression to
125    ///
126    /// ```text
127    /// a = b AND true AND c = d AND true
128    /// ```
129    ///
130    /// and returns the expression fragments `EXISTS (<subquery 1>)` and
131    /// `(<subquery 2>) = e` in the `out` vector.
132    fn extract_conjuncted_subqueries(
133        expr: &mut HirScalarExpr,
134        out: &mut Vec<HirScalarExpr>,
135    ) -> Result<(), RecursionLimitError> {
136        match expr {
137            HirScalarExpr::CallVariadic {
138                func: VariadicFunc::And,
139                exprs,
140            } => {
141                exprs
142                    .into_iter()
143                    .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
144            }
145            expr if contains_subquery(expr)? => {
146                out.push(mem::replace(expr, HirScalarExpr::literal_true()))
147            }
148            _ => (),
149        }
150        Ok(())
151    }
152
153    walk_relation(expr)
154}
155
156/// Rewrites quantified comparisons into simpler EXISTS operators.
157///
158/// Note that this transformation is only valid when the expression is
159/// used in a context where the distinction between `FALSE` and `NULL`
160/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
161/// when the inputs to the comparison are non-nullable. This function is careful
162/// to only apply the transformation when it is valid to do so.
163///
164/// ```ignore
165/// WHERE (SELECT any(<pred>) FROM <rel>)
166/// =>
167/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
168///
169/// WHERE (SELECT all(<pred>) FROM <rel>)
170/// =>
171/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
172/// ```
173///
174/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
175/// M. Elhemali, et al.
176pub fn try_simplify_quantified_comparisons(
177    expr: &mut HirRelationExpr,
178) -> Result<(), RecursionLimitError> {
179    fn walk_relation(
180        expr: &mut HirRelationExpr,
181        outers: &[RelationType],
182    ) -> Result<(), RecursionLimitError> {
183        match expr {
184            HirRelationExpr::Map { scalars, input } => {
185                walk_relation(input, outers)?;
186                let mut outers = outers.to_vec();
187                outers.insert(0, input.typ(&outers, &NO_PARAMS));
188                for scalar in scalars {
189                    walk_scalar(scalar, &outers, false)?;
190                    let (inner, outers) = outers
191                        .split_first_mut()
192                        .expect("outers known to have at least one element");
193                    let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
194                    inner.column_types.push(scalar_type);
195                }
196            }
197            HirRelationExpr::Filter { predicates, input } => {
198                walk_relation(input, outers)?;
199                let mut outers = outers.to_vec();
200                outers.insert(0, input.typ(&outers, &NO_PARAMS));
201                for pred in predicates {
202                    walk_scalar(pred, &outers, true)?;
203                }
204            }
205            HirRelationExpr::CallTable { exprs, .. } => {
206                let mut outers = outers.to_vec();
207                outers.insert(0, RelationType::empty());
208                for scalar in exprs {
209                    walk_scalar(scalar, &outers, false)?;
210                }
211            }
212            HirRelationExpr::Join { left, right, .. } => {
213                walk_relation(left, outers)?;
214                let mut outers = outers.to_vec();
215                outers.insert(0, left.typ(&outers, &NO_PARAMS));
216                walk_relation(right, &outers)?;
217            }
218            expr => {
219                #[allow(deprecated)]
220                let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
221                    walk_relation(expr, outers)
222                });
223            }
224        }
225        Ok(())
226    }
227
228    fn walk_scalar(
229        expr: &mut HirScalarExpr,
230        outers: &[RelationType],
231        mut in_filter: bool,
232    ) -> Result<(), RecursionLimitError> {
233        expr.try_visit_mut_pre(&mut |e| {
234            match e {
235                HirScalarExpr::Exists(input) => walk_relation(input, outers)?,
236                HirScalarExpr::Select(input) => {
237                    walk_relation(input, outers)?;
238
239                    // We're inside a `(SELECT ...)` subquery. Now let's see if
240                    // it has the form `(SELECT <any|all>(...) FROM <input>)`.
241                    // Ideally we could do this with one pattern, but Rust's pattern
242                    // matching engine is not powerful enough, so we have to do this
243                    // in stages; the early returns avoid brutal nesting.
244
245                    let (func, expr, input) = match &mut **input {
246                        HirRelationExpr::Reduce {
247                            group_key,
248                            aggregates,
249                            input,
250                            expected_group_size: _,
251                        } if group_key.is_empty() && aggregates.len() == 1 => {
252                            let agg = &mut aggregates[0];
253                            (&agg.func, &mut agg.expr, input)
254                        }
255                        _ => return Ok(()),
256                    };
257
258                    if !in_filter && column_type(outers, input, expr).nullable {
259                        // Unless we're directly inside a WHERE, this
260                        // transformation is only valid if the expression involved
261                        // is non-nullable.
262                        return Ok(());
263                    }
264
265                    match func {
266                        AggregateFunc::Any => {
267                            // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
268                            // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
269                            *e = input.take().filter(vec![expr.take()]).exists();
270                        }
271                        AggregateFunc::All => {
272                            // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
273                            // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
274                            //
275                            // Note that negation of <expr> alone is insufficient.
276                            // Consider that `WHERE <pred>` filters out rows if
277                            // `<pred>` is false *or* null. To invert the test, we
278                            // need `NOT <pred> OR <pred> IS NULL`.
279                            let expr = expr.take();
280                            let filter = expr.clone().not().or(expr.call_is_null());
281                            *e = input.take().filter(vec![filter]).exists().not();
282                        }
283                        _ => (),
284                    }
285                }
286                _ => {
287                    // As soon as we see *any* scalar expression, we are no longer
288                    // directly inside a filter.
289                    in_filter = false;
290                }
291            }
292            Ok(())
293        })
294    }
295
296    walk_relation(expr, &[])
297}
298
299/// An empty parameter type map.
300///
301/// These transformations are expected to run after parameters are bound, so
302/// there is no need to provide any parameter type information.
303static NO_PARAMS: LazyLock<BTreeMap<usize, ScalarType>> = LazyLock::new(BTreeMap::new);
304
305fn column_type(
306    outers: &[RelationType],
307    inner: &HirRelationExpr,
308    expr: &HirScalarExpr,
309) -> ColumnType {
310    let inner_type = inner.typ(outers, &NO_PARAMS);
311    expr.typ(outers, &inner_type, &NO_PARAMS)
312}
313
314impl HirScalarExpr {
315    /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
316    /// considers column references that target the root level.
317    /// (See `visit_columns_referring_to_root_level`.)
318    fn support(&self) -> Vec<usize> {
319        let mut result = Vec::new();
320        self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
321        result
322    }
323
324    /// Changes column references in `self` by the given remapping.
325    /// Panics if a referred column is not present in `idx_map`!
326    fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
327        self.visit_columns_referring_to_root_level_mut(&mut |c| {
328            *c = idx_map[c];
329        });
330        self
331    }
332}
333
334/// # Aims and scope
335///
336/// The aim here is to amortize the overhead of the MIR window function pattern
337/// (see `window_func_applied_to`) by fusing groups of window function calls such
338/// that each group can be performed by one instance of the window function MIR
339/// pattern.
340///
341/// For now, we fuse only value window function calls and window aggregations.
342/// (We probably won't need to fuse scalar window functions for a long time.)
343///
344/// For now, we can fuse value window function calls and window aggregations where the
345/// A. partition by
346/// B. order by
347/// C. window frame
348/// D. ignore nulls for value window functions and distinct for window aggregations
349/// are all the same. (See `extract_options`.)
350/// (Later, we could improve this to only need A. to be the same. This would require
351/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
352/// TODO: As a much simpler intermediate step, at least we should ignore options that
353/// don't matter. For example, we should be able to fuse a `lag` that has a default
354/// frame with a `first_value` that has some custom frame, because `lag` is not
355/// affected by the frame.)
356/// Note that we fuse value window function calls and window aggregations separately.
357///
358/// # Implementation
359///
360/// At a high level, what we are going to do is look for Maps with more than one window function
361/// calls, and for each Map
362/// - remove some groups of window function call expressions from the Map's `scalars`;
363/// - insert a fused version of each group;
364/// - insert some expressions that decompose the results of the fused calls;
365/// - update some column references in `scalars`: those that refer to window function results that
366///   participated in fusion, as well as those that refer to columns that moved around due to
367///   removing and inserting expressions.
368/// - insert a Project above the matched Map to permute columns back to their original places.
369///
370/// It would be tempting to find groups simply by taking a list of all window function calls
371/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
372/// but a complication is that the possible groups that we could theoretically fuse overlap.
373/// This is because when forming groups we need to also take into account column references
374/// that point inside the same Map. For example, imagine a Map with the following scalar
375/// expressions:
376/// C1, E1, C2, C3, where
377/// - E1 refers to C1
378/// - C3 refers to E1.
379/// In this situation, we could either
380/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
381///   to it);
382/// - or fuse C2 and C3.
383/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
384/// no appropriate place for the fused expression: it would have to be both before and after E1.
385///
386/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
387/// we go through `scalars`, try to put each expression into each of our groups, and the first of
388/// these succeed. When trying to put an expression into a group, we need to be mindful about column
389/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
390/// sanity is that the fused version of each group will be inserted at the place where the first
391/// element of the group originally was. This means that the only condition that we need to check on
392/// column references when adding an expression to a group is that all column references in a group
393/// should be to columns that are earlier than the first element of the group. (No need to check
394/// column references in the other direction, i.e., references in other expressions that refer to
395/// columns in the group.)
396pub fn fuse_window_functions(
397    root: &mut HirRelationExpr,
398    _context: &crate::plan::lowering::Context,
399) -> Result<(), RecursionLimitError> {
400    /// Those options of a window function call that are relevant for fusion.
401    #[derive(PartialEq, Eq)]
402    enum WindowFuncCallOptions {
403        Value(ValueWindowFuncCallOptions),
404        Agg(AggregateWindowFuncCallOptions),
405    }
406    #[derive(PartialEq, Eq)]
407    struct ValueWindowFuncCallOptions {
408        partition_by: Vec<HirScalarExpr>,
409        outer_order_by: Vec<HirScalarExpr>,
410        inner_order_by: Vec<ColumnOrder>,
411        window_frame: WindowFrame,
412        ignore_nulls: bool,
413    }
414    #[derive(PartialEq, Eq)]
415    struct AggregateWindowFuncCallOptions {
416        partition_by: Vec<HirScalarExpr>,
417        outer_order_by: Vec<HirScalarExpr>,
418        inner_order_by: Vec<ColumnOrder>,
419        window_frame: WindowFrame,
420        distinct: bool,
421    }
422
423    /// Helper function to extract the above options.
424    fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
425        match call {
426            HirScalarExpr::Windowing(WindowExpr {
427                func:
428                    WindowExprType::Value(ValueWindowExpr {
429                        order_by: inner_order_by,
430                        window_frame,
431                        ignore_nulls,
432                        func: _,
433                        args: _,
434                    }),
435                partition_by,
436                order_by: outer_order_by,
437            }) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
438                partition_by: partition_by.clone(),
439                outer_order_by: outer_order_by.clone(),
440                inner_order_by: inner_order_by.clone(),
441                window_frame: window_frame.clone(),
442                ignore_nulls: ignore_nulls.clone(),
443            }),
444            HirScalarExpr::Windowing(WindowExpr {
445                func:
446                    WindowExprType::Aggregate(AggregateWindowExpr {
447                        aggregate_expr:
448                            AggregateExpr {
449                                distinct,
450                                func: _,
451                                expr: _,
452                            },
453                        order_by: inner_order_by,
454                        window_frame,
455                    }),
456                partition_by,
457                order_by: outer_order_by,
458            }) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
459                partition_by: partition_by.clone(),
460                outer_order_by: outer_order_by.clone(),
461                inner_order_by: inner_order_by.clone(),
462                window_frame: window_frame.clone(),
463                distinct: distinct.clone(),
464            }),
465            _ => panic!(
466                "extract_options should only be called on value window functions or window aggregations"
467            ),
468        }
469    }
470
471    struct FusionGroup {
472        /// The original column index of the first element of the group. (This is an index into the
473        /// Map's `scalars` plus the arity of the Map's input.)
474        first_col: usize,
475        /// The options of all the window function calls in the group. (Must be the same for all the
476        /// calls.)
477        options: WindowFuncCallOptions,
478        /// The calls in the group, with their original column indexes.
479        calls: Vec<(usize, HirScalarExpr)>,
480    }
481
482    impl FusionGroup {
483        /// Creates a window function call that is a fused version of all the calls in the group.
484        /// `new_col` is the column index where the fused call will be inserted at.
485        fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
486            let fused = match self.options {
487                WindowFuncCallOptions::Value(options) => {
488                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
489                        .calls
490                        .iter()
491                        .map(|(_idx, call)| {
492                            if let HirScalarExpr::Windowing(WindowExpr {
493                                func:
494                                    WindowExprType::Value(ValueWindowExpr {
495                                        func,
496                                        args,
497                                        order_by: _,
498                                        window_frame: _,
499                                        ignore_nulls: _,
500                                    }),
501                                partition_by: _,
502                                order_by: _,
503                            }) = call
504                            {
505                                (func.clone(), (**args).clone())
506                            } else {
507                                panic!("unknown window function in FusionGroup")
508                            }
509                        })
510                        .unzip();
511                    let fused_args = HirScalarExpr::CallVariadic {
512                        func: VariadicFunc::RecordCreate {
513                            // These field names are not important, because this record will only be an
514                            // intermediate expression, which we'll manipulate further before it ends up
515                            // anywhere where a column name would be visible.
516                            field_names: iter::repeat(ColumnName::from(""))
517                                .take(fused_args.len())
518                                .collect(),
519                        },
520                        exprs: fused_args,
521                    };
522                    HirScalarExpr::Windowing(WindowExpr {
523                        func: WindowExprType::Value(ValueWindowExpr {
524                            func: ValueWindowFunc::Fused(fused_funcs),
525                            args: Box::new(fused_args),
526                            order_by: options.inner_order_by,
527                            window_frame: options.window_frame,
528                            ignore_nulls: options.ignore_nulls,
529                        }),
530                        partition_by: options.partition_by,
531                        order_by: options.outer_order_by,
532                    })
533                }
534                WindowFuncCallOptions::Agg(options) => {
535                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
536                        .calls
537                        .iter()
538                        .map(|(_idx, call)| {
539                            if let HirScalarExpr::Windowing(WindowExpr {
540                                func:
541                                    WindowExprType::Aggregate(AggregateWindowExpr {
542                                        aggregate_expr:
543                                            AggregateExpr {
544                                                func,
545                                                expr,
546                                                distinct: _,
547                                            },
548                                        order_by: _,
549                                        window_frame: _,
550                                    }),
551                                partition_by: _,
552                                order_by: _,
553                            }) = call
554                            {
555                                (func.clone(), (**expr).clone())
556                            } else {
557                                panic!("unknown window function in FusionGroup")
558                            }
559                        })
560                        .unzip();
561                    let fused_args = HirScalarExpr::CallVariadic {
562                        func: VariadicFunc::RecordCreate {
563                            field_names: iter::repeat(ColumnName::from(""))
564                                .take(fused_args.len())
565                                .collect(),
566                        },
567                        exprs: fused_args,
568                    };
569                    HirScalarExpr::Windowing(WindowExpr {
570                        func: WindowExprType::Aggregate(AggregateWindowExpr {
571                            aggregate_expr: AggregateExpr {
572                                func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
573                                expr: Box::new(fused_args),
574                                distinct: options.distinct,
575                            },
576                            order_by: options.inner_order_by,
577                            window_frame: options.window_frame,
578                        }),
579                        partition_by: options.partition_by,
580                        order_by: options.outer_order_by,
581                    })
582                }
583            };
584
585            let decompositions = (0..self.calls.len())
586                .map(|field| HirScalarExpr::CallUnary {
587                    func: UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)),
588                    expr: Box::new(HirScalarExpr::Column(ColumnRef {
589                        level: 0,
590                        column: new_col,
591                    })),
592                })
593                .collect();
594
595            (fused, decompositions)
596        }
597    }
598
599    let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
600        // Look for calls only at the root of scalar expressions. This is enough
601        // because they are always there, see 72e84bb78.
602        match scalar_expr {
603            HirScalarExpr::Windowing(WindowExpr {
604                func: WindowExprType::Value(ValueWindowExpr { func, .. }),
605                ..
606            }) => {
607                // Exclude those calls that are already fused. (We shouldn't currently
608                // encounter these, because we just do one pass, but it's better to be
609                // robust against future code changes.)
610                !matches!(func, ValueWindowFunc::Fused(..))
611            }
612            HirScalarExpr::Windowing(WindowExpr {
613                func:
614                    WindowExprType::Aggregate(AggregateWindowExpr {
615                        aggregate_expr: AggregateExpr { func, .. },
616                        ..
617                    }),
618                ..
619            }) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
620            _ => false,
621        }
622    };
623
624    root.try_visit_mut_post(&mut |rel_expr| {
625        match rel_expr {
626            HirRelationExpr::Map { input, scalars } => {
627                // There will be various variable names involving `idx` or `col`:
628                // - `idx` will always be an index into `scalars` or something similar,
629                // - `col` will always be a column index,
630                //   which is often `arity_before_map` + an index into `scalars`.
631                let arity_before_map = input.arity();
632                let orig_num_scalars = scalars.len();
633
634                // Collect all value window function calls and window aggregations with their column
635                // indexes.
636                let value_or_agg_window_func_calls = scalars
637                    .iter()
638                    .enumerate()
639                    .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
640                    .map(|(idx, call)| (idx + arity_before_map, call.clone()))
641                    .collect_vec();
642                // Exit early if obviously no chance for fusion.
643                if value_or_agg_window_func_calls.len() <= 1 {
644                    // Note that we are doing this only for performance. All plans should be exactly
645                    // the same even if we comment out the following line.
646                    return Ok(());
647                }
648
649                // Determine the fusion groups. (Each group will later be fused into one window
650                // function call.)
651                // Note that this has a quadratic run time with value_or_agg_window_func_calls in
652                // the worst case. However, this is fine even with 1000 window function calls.
653                let mut groups: Vec<FusionGroup> = Vec::new();
654                for (col, call) in value_or_agg_window_func_calls {
655                    let options = extract_options(&call);
656                    let support = call.support();
657                    let to_fuse_with = groups
658                        .iter_mut()
659                        .filter(|group| {
660                            group.options == options && support.iter().all(|c| *c < group.first_col)
661                        })
662                        .next();
663                    if let Some(group) = to_fuse_with {
664                        group.calls.push((col, call.clone()));
665                    } else {
666                        groups.push(FusionGroup {
667                            first_col: col,
668                            options,
669                            calls: vec![(col, call.clone())],
670                        });
671                    }
672                }
673
674                // No fusion to do on groups of 1.
675                groups.retain(|g| g.calls.len() > 1);
676
677                let removals: BTreeSet<usize> = groups
678                    .iter()
679                    .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
680                    .collect();
681
682                // Mutate `scalars`.
683                // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
684                // `groups` is already sorted by `first_col` due to the way it was constructed.)
685                // We also compute a remapping of old indexes to new indexes as we go.
686                let mut groups_it = groups.drain(..).peekable();
687                let mut group = groups_it.next();
688                let mut remap = BTreeMap::new();
689                remap.extend((0..arity_before_map).map(|col| (col, col)));
690                let mut new_col: usize = arity_before_map;
691                let mut new_scalars = Vec::new();
692                for (old_col, e) in scalars
693                    .drain(..)
694                    .enumerate()
695                    .map(|(idx, e)| (idx + arity_before_map, e))
696                {
697                    if group.as_ref().is_some_and(|g| g.first_col == old_col) {
698                        // The current expression will be fused away, and a fused expression will
699                        // appear in its place. Additionally, some new expressions will be inserted
700                        // after the fused expression, to decompose the record that is the result of
701                        // the fused call.
702                        assert!(removals.contains(&old_col));
703                        let group_unwrapped = group.expect("checked above");
704                        let calls_cols = group_unwrapped
705                            .calls
706                            .iter()
707                            .map(|(col, _call)| *col)
708                            .collect_vec();
709                        let (fused, decompositions) = group_unwrapped.fuse(new_col);
710                        new_scalars.push(fused.remap(&remap));
711                        new_scalars.extend(decompositions); // (no remapping needed)
712                        new_col += 1;
713                        for call_old_col in calls_cols {
714                            let present = remap.insert(call_old_col, new_col);
715                            assert!(present.is_none());
716                            new_col += 1;
717                        }
718                        group = groups_it.next();
719                    } else if removals.contains(&old_col) {
720                        assert!(remap.contains_key(&old_col));
721                    } else {
722                        new_scalars.push(e.remap(&remap));
723                        let present = remap.insert(old_col, new_col);
724                        assert!(present.is_none());
725                        new_col += 1;
726                    }
727                }
728                *scalars = new_scalars;
729                assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
730
731                // Add a project to permute columns back to their original places.
732                *rel_expr = rel_expr.take().project(
733                    (0..arity_before_map)
734                        .chain((0..orig_num_scalars).map(|idx| {
735                            *remap
736                                .get(&(idx + arity_before_map))
737                                .expect("all columns should be present by now")
738                        }))
739                        .collect(),
740                );
741
742                assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
743            }
744            _ => {}
745        }
746        Ok(())
747    })
748}