Skip to main content

mz_sql/plan/
transform_hir.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::func::variadic::RecordCreate;
19use mz_expr::visit::Visit;
20use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
21use mz_ore::stack::RecursionLimitError;
22use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
23
24use crate::plan::hir::{
25    AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
26    ValueWindowExpr, ValueWindowFunc, WindowExpr,
27};
28use crate::plan::{AggregateExpr, WindowExprType};
29
30/// Rewrites predicates that contain subqueries so that the subqueries
31/// appear in their own later predicate when possible.
32///
33/// For example, this function rewrites this expression
34///
35/// ```text
36/// Filter {
37///     predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
38/// }
39/// ```
40///
41/// like so:
42///
43/// ```text
44/// Filter {
45///     predicates: [
46///         a = b AND c = d,
47///         EXISTS (<subquery>),
48///         (<subquery 2>) = e,
49///     ]
50/// }
51/// ```
52///
53/// The rewrite causes decorrelation to incorporate prior predicates into
54/// the outer relation upon which the subquery is evaluated. In the above
55/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
56/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
57/// = e`, will be further restricted to outer rows that match `A = b AND c =
58/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
59/// subquery, especially when the original conjunction contains join keys.
60pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61    fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
62        #[allow(deprecated)]
63        expr.visit_mut_fallible(0, &mut |expr, _| {
64            match expr {
65                HirRelationExpr::Map { scalars, .. } => {
66                    for scalar in scalars {
67                        walk_scalar(scalar)?;
68                    }
69                }
70                HirRelationExpr::CallTable { exprs, .. } => {
71                    for expr in exprs {
72                        walk_scalar(expr)?;
73                    }
74                }
75                HirRelationExpr::Filter { predicates, .. } => {
76                    let mut subqueries = vec![];
77                    for predicate in &mut *predicates {
78                        walk_scalar(predicate)?;
79                        extract_conjuncted_subqueries(predicate, &mut subqueries)?;
80                    }
81                    // TODO(benesch): we could be smarter about the order in which
82                    // we emit subqueries. At the moment we just emit in the order
83                    // we discovered them, but ideally we'd emit them in an order
84                    // that accounted for their cost/selectivity. E.g., low-cost,
85                    // high-selectivity subqueries should go first.
86                    for subquery in subqueries {
87                        predicates.push(subquery);
88                    }
89                }
90                _ => (),
91            }
92            Ok(())
93        })
94    }
95
96    fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
97        expr.try_visit_mut_post(&mut |expr| {
98            match expr {
99                HirScalarExpr::Exists(input, _name) | HirScalarExpr::Select(input, _name) => {
100                    walk_relation(input)?
101                }
102                _ => (),
103            }
104            Ok(())
105        })
106    }
107
108    fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
109        let mut found = false;
110        expr.visit_pre(&mut |expr| match expr {
111            HirScalarExpr::Exists(..) | HirScalarExpr::Select(..) => found = true,
112            _ => (),
113        })?;
114        Ok(found)
115    }
116
117    /// Extracts subqueries from a conjunction into `out`.
118    ///
119    /// For example, given an expression like
120    ///
121    /// ```text
122    /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
123    /// ```
124    ///
125    /// this function rewrites the expression to
126    ///
127    /// ```text
128    /// a = b AND true AND c = d AND true
129    /// ```
130    ///
131    /// and returns the expression fragments `EXISTS (<subquery 1>)` and
132    /// `(<subquery 2>) = e` in the `out` vector.
133    fn extract_conjuncted_subqueries(
134        expr: &mut HirScalarExpr,
135        out: &mut Vec<HirScalarExpr>,
136    ) -> Result<(), RecursionLimitError> {
137        match expr {
138            HirScalarExpr::CallVariadic {
139                func: VariadicFunc::And(_),
140                exprs,
141                name: _,
142            } => {
143                exprs
144                    .into_iter()
145                    .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
146            }
147            expr if contains_subquery(expr)? => {
148                out.push(mem::replace(expr, HirScalarExpr::literal_true()))
149            }
150            _ => (),
151        }
152        Ok(())
153    }
154
155    walk_relation(expr)
156}
157
158/// Rewrites quantified comparisons into simpler EXISTS operators.
159///
160/// Note that this transformation is only valid when the expression is
161/// used in a context where the distinction between `FALSE` and `NULL`
162/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
163/// when the inputs to the comparison are non-nullable. This function is careful
164/// to only apply the transformation when it is valid to do so.
165///
166/// ```ignore
167/// WHERE (SELECT any(<pred>) FROM <rel>)
168/// =>
169/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
170///
171/// WHERE (SELECT all(<pred>) FROM <rel>)
172/// =>
173/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
174/// ```
175///
176/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
177/// M. Elhemali, et al.
178pub fn try_simplify_quantified_comparisons(
179    expr: &mut HirRelationExpr,
180) -> Result<(), RecursionLimitError> {
181    fn walk_relation(
182        expr: &mut HirRelationExpr,
183        outers: &[SqlRelationType],
184    ) -> Result<(), RecursionLimitError> {
185        match expr {
186            HirRelationExpr::Map { scalars, input } => {
187                walk_relation(input, outers)?;
188                let mut outers = outers.to_vec();
189                outers.insert(0, input.typ(&outers, &NO_PARAMS));
190                for scalar in scalars {
191                    walk_scalar(scalar, &outers, false)?;
192                    let (inner, outers) = outers
193                        .split_first_mut()
194                        .expect("outers known to have at least one element");
195                    let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
196                    inner.column_types.push(scalar_type);
197                }
198            }
199            HirRelationExpr::Filter { predicates, input } => {
200                walk_relation(input, outers)?;
201                let mut outers = outers.to_vec();
202                outers.insert(0, input.typ(&outers, &NO_PARAMS));
203                for pred in predicates {
204                    walk_scalar(pred, &outers, true)?;
205                }
206            }
207            HirRelationExpr::CallTable { exprs, .. } => {
208                let mut outers = outers.to_vec();
209                outers.insert(0, SqlRelationType::empty());
210                for scalar in exprs {
211                    walk_scalar(scalar, &outers, false)?;
212                }
213            }
214            HirRelationExpr::Join { left, right, .. } => {
215                walk_relation(left, outers)?;
216                let mut outers = outers.to_vec();
217                outers.insert(0, left.typ(&outers, &NO_PARAMS));
218                walk_relation(right, &outers)?;
219            }
220            expr => {
221                #[allow(deprecated)]
222                let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
223                    walk_relation(expr, outers)
224                });
225            }
226        }
227        Ok(())
228    }
229
230    fn walk_scalar(
231        expr: &mut HirScalarExpr,
232        outers: &[SqlRelationType],
233        mut in_filter: bool,
234    ) -> Result<(), RecursionLimitError> {
235        expr.try_visit_mut_pre(&mut |e| {
236            match e {
237                HirScalarExpr::Exists(input, _name) => walk_relation(input, outers)?,
238                HirScalarExpr::Select(input, _name) => {
239                    walk_relation(input, outers)?;
240
241                    // We're inside a `(SELECT ...)` subquery. Now let's see if
242                    // it has the form `(SELECT <any|all>(...) FROM <input>)`.
243                    // Ideally we could do this with one pattern, but Rust's pattern
244                    // matching engine is not powerful enough, so we have to do this
245                    // in stages; the early returns avoid brutal nesting.
246
247                    let (func, expr, input) = match &mut **input {
248                        HirRelationExpr::Reduce {
249                            group_key,
250                            aggregates,
251                            input,
252                            expected_group_size: _,
253                        } if group_key.is_empty() && aggregates.len() == 1 => {
254                            let agg = &mut aggregates[0];
255                            (&agg.func, &mut agg.expr, input)
256                        }
257                        _ => return Ok(()),
258                    };
259
260                    if !in_filter && column_type(outers, input, expr).nullable {
261                        // Unless we're directly inside a WHERE, this
262                        // transformation is only valid if the expression involved
263                        // is non-nullable.
264                        return Ok(());
265                    }
266
267                    match func {
268                        AggregateFunc::Any => {
269                            // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
270                            // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
271                            *e = input.take().filter(vec![expr.take()]).exists();
272                        }
273                        AggregateFunc::All => {
274                            // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
275                            // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
276                            //
277                            // Note that negation of <expr> alone is insufficient.
278                            // Consider that `WHERE <pred>` filters out rows if
279                            // `<pred>` is false *or* null. To invert the test, we
280                            // need `NOT <pred> OR <pred> IS NULL`.
281                            let expr = expr.take();
282                            let filter = expr.clone().not().or(expr.call_is_null());
283                            *e = input.take().filter(vec![filter]).exists().not();
284                        }
285                        _ => (),
286                    }
287                }
288                _ => {
289                    // As soon as we see *any* scalar expression, we are no longer
290                    // directly inside a filter.
291                    in_filter = false;
292                }
293            }
294            Ok(())
295        })
296    }
297
298    walk_relation(expr, &[])
299}
300
301/// An empty parameter type map.
302///
303/// These transformations are expected to run after parameters are bound, so
304/// there is no need to provide any parameter type information.
305static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
306
307fn column_type(
308    outers: &[SqlRelationType],
309    inner: &HirRelationExpr,
310    expr: &HirScalarExpr,
311) -> SqlColumnType {
312    let inner_type = inner.typ(outers, &NO_PARAMS);
313    expr.typ(outers, &inner_type, &NO_PARAMS)
314}
315
316impl HirScalarExpr {
317    /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
318    /// considers column references that target the root level.
319    /// (See `visit_columns_referring_to_root_level`.)
320    fn support(&self) -> Vec<usize> {
321        let mut result = Vec::new();
322        self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
323        result
324    }
325
326    /// Changes column references in `self` by the given remapping.
327    /// Panics if a referred column is not present in `idx_map`!
328    fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
329        self.visit_columns_referring_to_root_level_mut(&mut |c| {
330            *c = idx_map[c];
331        });
332        self
333    }
334}
335
336/// # Aims and scope
337///
338/// The aim here is to amortize the overhead of the MIR window function pattern
339/// (see `window_func_applied_to`) by fusing groups of window function calls such
340/// that each group can be performed by one instance of the window function MIR
341/// pattern.
342///
343/// For now, we fuse only value window function calls and window aggregations.
344/// (We probably won't need to fuse scalar window functions for a long time.)
345///
346/// For now, we can fuse value window function calls and window aggregations where the
347/// A. partition by
348/// B. order by
349/// C. window frame
350/// D. ignore nulls for value window functions and distinct for window aggregations
351/// are all the same. (See `extract_options`.)
352/// (Later, we could improve this to only need A. to be the same. This would require
353/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
354/// TODO: As a much simpler intermediate step, at least we should ignore options that
355/// don't matter. For example, we should be able to fuse a `lag` that has a default
356/// frame with a `first_value` that has some custom frame, because `lag` is not
357/// affected by the frame.)
358/// Note that we fuse value window function calls and window aggregations separately.
359///
360/// # Implementation
361///
/// At a high level, what we are going to do is look for Maps with more than one window function
/// call, and for each such Map
364/// - remove some groups of window function call expressions from the Map's `scalars`;
365/// - insert a fused version of each group;
366/// - insert some expressions that decompose the results of the fused calls;
367/// - update some column references in `scalars`: those that refer to window function results that
368///   participated in fusion, as well as those that refer to columns that moved around due to
369///   removing and inserting expressions.
370/// - insert a Project above the matched Map to permute columns back to their original places.
371///
372/// It would be tempting to find groups simply by taking a list of all window function calls
373/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
374/// but a complication is that the possible groups that we could theoretically fuse overlap.
375/// This is because when forming groups we need to also take into account column references
376/// that point inside the same Map. For example, imagine a Map with the following scalar
377/// expressions:
378/// C1, E1, C2, C3, where
379/// - E1 refers to C1
380/// - C3 refers to E1.
381/// In this situation, we could either
382/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
383///   to it);
384/// - or fuse C2 and C3.
385/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
386/// no appropriate place for the fused expression: it would have to be both before and after E1.
387///
/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
/// we go through `scalars` and try to put each expression into each of our groups in turn; the
/// first group that accepts it wins (and if none does, the expression starts a new group). When
/// trying to put an expression into a group, we need to be mindful about column references
/// inside the same Map, as noted above. A constraint that we impose on ourselves for
392/// sanity is that the fused version of each group will be inserted at the place where the first
393/// element of the group originally was. This means that the only condition that we need to check on
394/// column references when adding an expression to a group is that all column references in a group
395/// should be to columns that are earlier than the first element of the group. (No need to check
396/// column references in the other direction, i.e., references in other expressions that refer to
397/// columns in the group.)
398pub fn fuse_window_functions(
399    root: &mut HirRelationExpr,
400    _context: &crate::plan::lowering::Context,
401) -> Result<(), RecursionLimitError> {
402    /// Those options of a window function call that are relevant for fusion.
403    #[derive(PartialEq, Eq)]
404    enum WindowFuncCallOptions {
405        Value(ValueWindowFuncCallOptions),
406        Agg(AggregateWindowFuncCallOptions),
407    }
408    #[derive(PartialEq, Eq)]
409    struct ValueWindowFuncCallOptions {
410        partition_by: Vec<HirScalarExpr>,
411        outer_order_by: Vec<HirScalarExpr>,
412        inner_order_by: Vec<ColumnOrder>,
413        window_frame: WindowFrame,
414        ignore_nulls: bool,
415    }
416    #[derive(PartialEq, Eq)]
417    struct AggregateWindowFuncCallOptions {
418        partition_by: Vec<HirScalarExpr>,
419        outer_order_by: Vec<HirScalarExpr>,
420        inner_order_by: Vec<ColumnOrder>,
421        window_frame: WindowFrame,
422        distinct: bool,
423    }
424
425    /// Helper function to extract the above options.
426    fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
427        match call {
428            HirScalarExpr::Windowing(
429                WindowExpr {
430                    func:
431                        WindowExprType::Value(ValueWindowExpr {
432                            order_by: inner_order_by,
433                            window_frame,
434                            ignore_nulls,
435                            func: _,
436                            args: _,
437                        }),
438                    partition_by,
439                    order_by: outer_order_by,
440                },
441                _name,
442            ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
443                partition_by: partition_by.clone(),
444                outer_order_by: outer_order_by.clone(),
445                inner_order_by: inner_order_by.clone(),
446                window_frame: window_frame.clone(),
447                ignore_nulls: ignore_nulls.clone(),
448            }),
449            HirScalarExpr::Windowing(
450                WindowExpr {
451                    func:
452                        WindowExprType::Aggregate(AggregateWindowExpr {
453                            aggregate_expr:
454                                AggregateExpr {
455                                    distinct,
456                                    func: _,
457                                    expr: _,
458                                },
459                            order_by: inner_order_by,
460                            window_frame,
461                        }),
462                    partition_by,
463                    order_by: outer_order_by,
464                },
465                _name,
466            ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
467                partition_by: partition_by.clone(),
468                outer_order_by: outer_order_by.clone(),
469                inner_order_by: inner_order_by.clone(),
470                window_frame: window_frame.clone(),
471                distinct: distinct.clone(),
472            }),
473            _ => panic!(
474                "extract_options should only be called on value window functions or window aggregations"
475            ),
476        }
477    }
478
479    struct FusionGroup {
480        /// The original column index of the first element of the group. (This is an index into the
481        /// Map's `scalars` plus the arity of the Map's input.)
482        first_col: usize,
483        /// The options of all the window function calls in the group. (Must be the same for all the
484        /// calls.)
485        options: WindowFuncCallOptions,
486        /// The calls in the group, with their original column indexes.
487        calls: Vec<(usize, HirScalarExpr)>,
488    }
489
490    impl FusionGroup {
491        /// Creates a window function call that is a fused version of all the calls in the group.
492        /// `new_col` is the column index where the fused call will be inserted at.
493        fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
494            let fused = match self.options {
495                WindowFuncCallOptions::Value(options) => {
496                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
497                        .calls
498                        .iter()
499                        .map(|(_idx, call)| {
500                            if let HirScalarExpr::Windowing(
501                                WindowExpr {
502                                    func:
503                                        WindowExprType::Value(ValueWindowExpr {
504                                            func,
505                                            args,
506                                            order_by: _,
507                                            window_frame: _,
508                                            ignore_nulls: _,
509                                        }),
510                                    partition_by: _,
511                                    order_by: _,
512                                },
513                                _name,
514                            ) = call
515                            {
516                                (func.clone(), (**args).clone())
517                            } else {
518                                panic!("unknown window function in FusionGroup")
519                            }
520                        })
521                        .unzip();
522                    let fused_args = HirScalarExpr::call_variadic(
523                        RecordCreate {
524                            // These field names are not important, because this record will only be an
525                            // intermediate expression, which we'll manipulate further before it ends up
526                            // anywhere where a column name would be visible.
527                            field_names: iter::repeat(ColumnName::from(""))
528                                .take(fused_args.len())
529                                .collect(),
530                        },
531                        fused_args,
532                    );
533                    HirScalarExpr::windowing(WindowExpr {
534                        func: WindowExprType::Value(ValueWindowExpr {
535                            func: ValueWindowFunc::Fused(fused_funcs),
536                            args: Box::new(fused_args),
537                            order_by: options.inner_order_by,
538                            window_frame: options.window_frame,
539                            ignore_nulls: options.ignore_nulls,
540                        }),
541                        partition_by: options.partition_by,
542                        order_by: options.outer_order_by,
543                    })
544                }
545                WindowFuncCallOptions::Agg(options) => {
546                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
547                        .calls
548                        .iter()
549                        .map(|(_idx, call)| {
550                            if let HirScalarExpr::Windowing(
551                                WindowExpr {
552                                    func:
553                                        WindowExprType::Aggregate(AggregateWindowExpr {
554                                            aggregate_expr:
555                                                AggregateExpr {
556                                                    func,
557                                                    expr,
558                                                    distinct: _,
559                                                },
560                                            order_by: _,
561                                            window_frame: _,
562                                        }),
563                                    partition_by: _,
564                                    order_by: _,
565                                },
566                                _name,
567                            ) = call
568                            {
569                                (func.clone(), (**expr).clone())
570                            } else {
571                                panic!("unknown window function in FusionGroup")
572                            }
573                        })
574                        .unzip();
575                    let fused_args = HirScalarExpr::call_variadic(
576                        RecordCreate {
577                            field_names: iter::repeat(ColumnName::from(""))
578                                .take(fused_args.len())
579                                .collect(),
580                        },
581                        fused_args,
582                    );
583                    HirScalarExpr::windowing(WindowExpr {
584                        func: WindowExprType::Aggregate(AggregateWindowExpr {
585                            aggregate_expr: AggregateExpr {
586                                func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
587                                expr: Box::new(fused_args),
588                                distinct: options.distinct,
589                            },
590                            order_by: options.inner_order_by,
591                            window_frame: options.window_frame,
592                        }),
593                        partition_by: options.partition_by,
594                        order_by: options.outer_order_by,
595                    })
596                }
597            };
598
599            let decompositions = (0..self.calls.len())
600                .map(|field| {
601                    HirScalarExpr::column(new_col)
602                        .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
603                })
604                .collect();
605
606            (fused, decompositions)
607        }
608    }
609
610    let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
611        // Look for calls only at the root of scalar expressions. This is enough
612        // because they are always there, see 72e84bb78.
613        match scalar_expr {
614            HirScalarExpr::Windowing(
615                WindowExpr {
616                    func: WindowExprType::Value(ValueWindowExpr { func, .. }),
617                    ..
618                },
619                _name,
620            ) => {
621                // Exclude those calls that are already fused. (We shouldn't currently
622                // encounter these, because we just do one pass, but it's better to be
623                // robust against future code changes.)
624                !matches!(func, ValueWindowFunc::Fused(..))
625            }
626            HirScalarExpr::Windowing(
627                WindowExpr {
628                    func:
629                        WindowExprType::Aggregate(AggregateWindowExpr {
630                            aggregate_expr: AggregateExpr { func, .. },
631                            ..
632                        }),
633                    ..
634                },
635                _name,
636            ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
637            _ => false,
638        }
639    };
640
641    root.try_visit_mut_post(&mut |rel_expr| {
642        match rel_expr {
643            HirRelationExpr::Map { input, scalars } => {
644                // There will be various variable names involving `idx` or `col`:
645                // - `idx` will always be an index into `scalars` or something similar,
646                // - `col` will always be a column index,
647                //   which is often `arity_before_map` + an index into `scalars`.
648                let arity_before_map = input.arity();
649                let orig_num_scalars = scalars.len();
650
651                // Collect all value window function calls and window aggregations with their column
652                // indexes.
653                let value_or_agg_window_func_calls = scalars
654                    .iter()
655                    .enumerate()
656                    .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
657                    .map(|(idx, call)| (idx + arity_before_map, call.clone()))
658                    .collect_vec();
659                // Exit early if obviously no chance for fusion.
660                if value_or_agg_window_func_calls.len() <= 1 {
661                    // Note that we are doing this only for performance. All plans should be exactly
662                    // the same even if we comment out the following line.
663                    return Ok(());
664                }
665
666                // Determine the fusion groups. (Each group will later be fused into one window
667                // function call.)
668                // Note that this has a quadratic run time with value_or_agg_window_func_calls in
669                // the worst case. However, this is fine even with 1000 window function calls.
670                let mut groups: Vec<FusionGroup> = Vec::new();
671                for (col, call) in value_or_agg_window_func_calls {
672                    let options = extract_options(&call);
673                    let support = call.support();
674                    let to_fuse_with = groups
675                        .iter_mut()
676                        .filter(|group| {
677                            group.options == options && support.iter().all(|c| *c < group.first_col)
678                        })
679                        .next();
680                    if let Some(group) = to_fuse_with {
681                        group.calls.push((col, call.clone()));
682                    } else {
683                        groups.push(FusionGroup {
684                            first_col: col,
685                            options,
686                            calls: vec![(col, call.clone())],
687                        });
688                    }
689                }
690
691                // No fusion to do on groups of 1.
692                groups.retain(|g| g.calls.len() > 1);
693
694                let removals: BTreeSet<usize> = groups
695                    .iter()
696                    .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
697                    .collect();
698
699                // Mutate `scalars`.
700                // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
701                // `groups` is already sorted by `first_col` due to the way it was constructed.)
702                // We also compute a remapping of old indexes to new indexes as we go.
703                let mut groups_it = groups.drain(..).peekable();
704                let mut group = groups_it.next();
705                let mut remap = BTreeMap::new();
706                remap.extend((0..arity_before_map).map(|col| (col, col)));
707                let mut new_col: usize = arity_before_map;
708                let mut new_scalars = Vec::new();
709                for (old_col, e) in scalars
710                    .drain(..)
711                    .enumerate()
712                    .map(|(idx, e)| (idx + arity_before_map, e))
713                {
714                    if group.as_ref().is_some_and(|g| g.first_col == old_col) {
715                        // The current expression will be fused away, and a fused expression will
716                        // appear in its place. Additionally, some new expressions will be inserted
717                        // after the fused expression, to decompose the record that is the result of
718                        // the fused call.
719                        assert!(removals.contains(&old_col));
720                        let group_unwrapped = group.expect("checked above");
721                        let calls_cols = group_unwrapped
722                            .calls
723                            .iter()
724                            .map(|(col, _call)| *col)
725                            .collect_vec();
726                        let (fused, decompositions) = group_unwrapped.fuse(new_col);
727                        new_scalars.push(fused.remap(&remap));
728                        new_scalars.extend(decompositions); // (no remapping needed)
729                        new_col += 1;
730                        for call_old_col in calls_cols {
731                            let present = remap.insert(call_old_col, new_col);
732                            assert!(present.is_none());
733                            new_col += 1;
734                        }
735                        group = groups_it.next();
736                    } else if removals.contains(&old_col) {
737                        assert!(remap.contains_key(&old_col));
738                    } else {
739                        new_scalars.push(e.remap(&remap));
740                        let present = remap.insert(old_col, new_col);
741                        assert!(present.is_none());
742                        new_col += 1;
743                    }
744                }
745                *scalars = new_scalars;
746                assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
747
748                // Add a project to permute columns back to their original places.
749                *rel_expr = rel_expr.take().project(
750                    (0..arity_before_map)
751                        .chain((0..orig_num_scalars).map(|idx| {
752                            *remap
753                                .get(&(idx + arity_before_map))
754                                .expect("all columns should be present by now")
755                        }))
756                        .collect(),
757                );
758
759                assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
760            }
761            _ => {}
762        }
763        Ok(())
764    })
765}