1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Transformations of SQL IR, before decorrelation.

use std::collections::{BTreeMap, BTreeSet};
use std::sync::LazyLock;
use std::{iter, mem};

use itertools::Itertools;
use mz_expr::visit::Visit;
use mz_expr::WindowFrame;
use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
use mz_ore::stack::RecursionLimitError;
use mz_repr::{ColumnName, ColumnType, RelationType, ScalarType};

use crate::plan::expr::{
    AbstractExpr, AggregateFunc, AggregateWindowExpr, ColumnRef, HirRelationExpr, HirScalarExpr,
    ValueWindowExpr, ValueWindowFunc, WindowExpr,
};
use crate::plan::{AggregateExpr, WindowExprType};

/// Rewrites predicates that contain subqueries so that the subqueries
/// appear in their own later predicate when possible.
///
/// For example, this function rewrites this expression
///
/// ```text
/// Filter {
///     predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
/// }
/// ```
///
/// like so:
///
/// ```text
/// Filter {
///     predicates: [
///         a = b AND c = d,
///         EXISTS (<subquery 1>),
///         (<subquery 2>) = e,
///     ]
/// }
/// ```
///
/// The rewrite causes decorrelation to incorporate prior predicates into
/// the outer relation upon which the subquery is evaluated. In the above
/// rewritten example, the `EXISTS (<subquery 1>)` will only be evaluated for
/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
/// = e`, will be further restricted to outer rows that match `a = b AND c =
/// d AND EXISTS (<subquery 1>)`. This can vastly reduce the cost of the
/// subquery, especially when the original conjunction contains join keys.
pub fn split_subquery_predicates(expr: &mut HirRelationExpr) {
    /// Visits every relation reachable from `relation` and rewrites the
    /// scalar expressions held by `Map`, `CallTable`, and `Filter` nodes.
    /// For `Filter` nodes, subquery-bearing conjuncts are additionally moved
    /// to the end of the predicate list.
    fn rewrite_relation(relation: &mut HirRelationExpr) {
        #[allow(deprecated)]
        relation.visit_mut(0, &mut |node, _depth| match node {
            HirRelationExpr::Map { scalars, .. } => scalars.iter_mut().for_each(rewrite_scalar),
            HirRelationExpr::CallTable { exprs, .. } => exprs.iter_mut().for_each(rewrite_scalar),
            HirRelationExpr::Filter { predicates, .. } => {
                let mut hoisted = Vec::new();
                for predicate in predicates.iter_mut() {
                    // Handle nested subqueries first, then pull the
                    // subquery-bearing conjuncts out of this predicate.
                    rewrite_scalar(predicate);
                    hoist_conjuncted_subqueries(predicate, &mut hoisted);
                }
                // TODO(benesch): we could be smarter about the order in which
                // we emit subqueries. At the moment we just emit in the order
                // we discovered them, but ideally we'd emit them in an order
                // that accounted for their cost/selectivity. E.g., low-cost,
                // high-selectivity subqueries should go first.
                predicates.extend(hoisted);
            }
            _ => {}
        });
    }

    /// Recurses into the relation of every `EXISTS`/`SELECT` subquery found
    /// anywhere inside `scalar`.
    fn rewrite_scalar(scalar: &mut HirScalarExpr) {
        #[allow(deprecated)]
        scalar.visit_mut(&mut |node| {
            if let HirScalarExpr::Exists(subquery) | HirScalarExpr::Select(subquery) = node {
                rewrite_relation(subquery);
            }
        })
    }

    /// Reports whether `scalar` contains an `EXISTS` or `SELECT` subquery
    /// anywhere in its expression tree.
    fn contains_subquery(scalar: &HirScalarExpr) -> bool {
        let mut seen = false;
        scalar.visit(&mut |node| {
            if matches!(node, HirScalarExpr::Exists(_) | HirScalarExpr::Select(_)) {
                seen = true;
            }
        });
        seen
    }

    /// Moves the subquery-bearing conjuncts of `scalar` into `out`, leaving
    /// a literal `true` behind in each vacated position.
    ///
    /// For example, given an expression like
    ///
    /// ```text
    /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
    /// ```
    ///
    /// the expression is rewritten in place to
    ///
    /// ```text
    /// a = b AND true AND c = d AND true
    /// ```
    ///
    /// and the fragments `EXISTS (<subquery 1>)` and `(<subquery 2>) = e`
    /// are appended to `out`, in that order.
    fn hoist_conjuncted_subqueries(scalar: &mut HirScalarExpr, out: &mut Vec<HirScalarExpr>) {
        if let HirScalarExpr::CallVariadic {
            func: VariadicFunc::And,
            exprs,
        } = scalar
        {
            // A conjunction: look at each conjunct individually (this also
            // flattens nested `AND`s).
            for conjunct in exprs.iter_mut() {
                hoist_conjuncted_subqueries(conjunct, out);
            }
        } else if contains_subquery(scalar) {
            out.push(mem::replace(scalar, HirScalarExpr::literal_true()));
        }
    }

    rewrite_relation(expr)
}

/// Rewrites quantified comparisons into simpler EXISTS operators.
///
/// Note that this transformation is only valid when the expression is
/// used in a context where the distinction between `FALSE` and `NULL`
/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
/// when the inputs to the comparison are non-nullable. This function is careful
/// to only apply the transformation when it is valid to do so.
///
/// ```ignore
/// WHERE (SELECT any(<pred>) FROM <rel>)
/// =>
/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
///
/// WHERE (SELECT all(<pred>) FROM <rel>)
/// =>
/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
/// ```
///
/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
/// M. Elhemali, et al.
pub fn try_simplify_quantified_comparisons(expr: &mut HirRelationExpr) {
    /// Builds the scope seen by expressions evaluated on top of a relation of
    /// type `innermost`: `innermost` first, followed by all of `outers`.
    fn enter_scope(innermost: RelationType, outers: &[RelationType]) -> Vec<RelationType> {
        let mut scope = Vec::with_capacity(outers.len() + 1);
        scope.push(innermost);
        scope.extend_from_slice(outers);
        scope
    }

    /// Walks `relation`, rewriting the scalar expressions it holds. `outers`
    /// are the types of the enclosing relations, innermost first.
    fn rewrite_relation(relation: &mut HirRelationExpr, outers: &[RelationType]) {
        match relation {
            HirRelationExpr::Map { scalars, input } => {
                rewrite_relation(input, outers);
                let mut scope = enter_scope(input.typ(outers, &NO_PARAMS), outers);
                for scalar in scalars {
                    rewrite_scalar(scalar, &scope, false);
                    // Each mapped scalar can see the ones before it, so grow
                    // the innermost relation type by the new column.
                    let (current, enclosing) = scope
                        .split_first_mut()
                        .expect("outers known to have at least one element");
                    let scalar_type = scalar.typ(enclosing, current, &NO_PARAMS);
                    current.column_types.push(scalar_type);
                }
            }
            HirRelationExpr::Filter { predicates, input } => {
                rewrite_relation(input, outers);
                let scope = enter_scope(input.typ(outers, &NO_PARAMS), outers);
                for predicate in predicates {
                    // Only the root of a filter predicate counts as being
                    // "directly inside of a WHERE".
                    rewrite_scalar(predicate, &scope, true);
                }
            }
            HirRelationExpr::CallTable { exprs, .. } => {
                let scope = enter_scope(RelationType::empty(), outers);
                for argument in exprs {
                    rewrite_scalar(argument, &scope, false);
                }
            }
            HirRelationExpr::Join { left, right, .. } => {
                rewrite_relation(left, outers);
                // The right side is walked with the left side's type pushed
                // onto the scope.
                let scope = enter_scope(left.typ(outers, &NO_PARAMS), outers);
                rewrite_relation(right, &scope);
            }
            other => {
                // Every other operator: simply descend into its direct
                // children with an unchanged scope.
                #[allow(deprecated)]
                let _ = other.visit1_mut(0, &mut |child, _depth| -> Result<(), ()> {
                    rewrite_relation(child, outers);
                    Ok(())
                });
            }
        }
    }

    /// Visits `scalar` in pre-order and rewrites eligible `(SELECT any(..))`
    /// and `(SELECT all(..))` subqueries. `in_filter` states whether `scalar`
    /// is the root of a filter predicate.
    fn rewrite_scalar(scalar: &mut HirScalarExpr, outers: &[RelationType], mut in_filter: bool) {
        #[allow(deprecated)]
        scalar.visit_mut_pre(&mut |node| match node {
            HirScalarExpr::Exists(subquery) => rewrite_relation(subquery, outers),
            HirScalarExpr::Select(subquery) => {
                rewrite_relation(subquery, outers);

                // Check, step by step, whether the subquery has the shape
                // `(SELECT <any|all>(...) FROM <input>)`: a `Reduce` with no
                // grouping key and exactly one aggregate. Bail out as soon as
                // a step does not match.
                let HirRelationExpr::Reduce {
                    group_key,
                    aggregates,
                    input,
                    expected_group_size: _,
                } = &mut **subquery
                else {
                    return;
                };
                if !group_key.is_empty() || aggregates.len() != 1 {
                    return;
                }
                let aggregate = &mut aggregates[0];
                let (func, argument) = (&aggregate.func, &mut aggregate.expr);

                if !in_filter && column_type(outers, input, argument).nullable {
                    // Outside of a WHERE root the difference between FALSE
                    // and NULL is observable, so the rewrite is only valid
                    // when the aggregated expression cannot be null.
                    return;
                }

                match func {
                    AggregateFunc::Any => {
                        // `(SELECT any(<expr>) FROM <input>)` becomes
                        // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
                        let matching = input.take().filter(vec![argument.take()]);
                        *node = matching.exists();
                    }
                    AggregateFunc::All => {
                        // `(SELECT all(<expr>) FROM <input>)` becomes
                        // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
                        //
                        // Negating <expr> alone would not be enough: a
                        // `WHERE <pred>` drops rows for which `<pred>` is
                        // false *or* null, so the inverted test has to be
                        // `NOT <pred> OR <pred> IS NULL`.
                        let predicate = argument.take();
                        let negated = predicate.clone().not();
                        let is_null = predicate.call_is_null();
                        let failing = input.take().filter(vec![negated.or(is_null)]);
                        *node = failing.exists().not();
                    }
                    _ => {}
                }
            }
            _ => {
                // Having seen *any* other scalar expression, everything
                // visited from here on is no longer directly inside of a
                // filter.
                in_filter = false;
            }
        })
    }

    rewrite_relation(expr, &[])
}

/// An empty parameter type map.
///
/// These transformations are expected to run after parameters are bound, so
/// there is no need to provide any parameter type information.
///
/// The map is constructed lazily by `LazyLock` on first access, and is what
/// this file passes as the parameter-types argument of its `typ` calls.
static NO_PARAMS: LazyLock<BTreeMap<usize, ScalarType>> = LazyLock::new(BTreeMap::new);

/// Computes the type of the scalar `expr` when it is evaluated on top of the
/// relation `inner`.
///
/// `outers` is forwarded unchanged both to the typing of `inner` and to the
/// typing of `expr`; no parameter types are supplied (see `NO_PARAMS`).
fn column_type(
    outers: &[RelationType],
    inner: &HirRelationExpr,
    expr: &HirScalarExpr,
) -> ColumnType {
    expr.typ(outers, &inner.typ(outers, &NO_PARAMS), &NO_PARAMS)
}

/// # Aims and scope
///
/// The aim here is to amortize the overhead of the MIR window function pattern
/// (see `window_func_applied_to`) by fusing groups of window function calls such
/// that each group can be performed by one instance of the window function MIR
/// pattern.
///
/// For now, we fuse only value window function calls and window aggregations.
/// (We probably won't need to fuse scalar window functions for a long time.)
///
/// For now, we can fuse value window function calls and window aggregations where the
/// A. partition by
/// B. order by
/// C. window frame
/// D. ignore nulls for value window functions and distinct for window aggregations
/// are all the same. (See `extract_options`.)
/// (Later, we could improve this to only need A. to be the same. This would require
/// many more code changes, because then we'd have to blow up `ValueWindowExpr`.
/// TODO: As a much simpler intermediate step, at least we should ignore options that
/// don't matter. For example, we should be able to fuse a `lag` that has a default
/// frame with a `first_value` that has some custom frame, because `lag` is not
/// affected by the frame.)
/// Note that we fuse value window function calls and window aggregations separately.
///
/// # Implementation
///
/// At a high level, what we are going to do is look for Maps with more than one window function
/// call, and for each Map
/// - remove some groups of window function call expressions from the Map's `scalars`;
/// - insert a fused version of each group;
/// - insert some expressions that decompose the results of the fused calls;
/// - update some column references in `scalars`: those that refer to window function results that
///   participated in fusion, as well as those that refer to columns that moved around due to
///   removing and inserting expressions.
/// - insert a Project above the matched Map to permute columns back to their original places.
///
/// It would be tempting to find groups simply by taking a list of all window function calls
/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
/// but a complication is that the possible groups that we could theoretically fuse overlap.
/// This is because when forming groups we need to also take into account column references
/// that point inside the same Map. For example, imagine a Map with the following scalar
/// expressions:
/// C1, E1, C2, C3, where
/// - E1 refers to C1
/// - C3 refers to E1.
/// In this situation, we could either
/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
///   to it);
/// - or fuse C2 and C3.
/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
/// no appropriate place for the fused expression: it would have to be both before and after E1.
///
/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
/// we go through `scalars` and try to put each window function call into each of our groups in
/// turn; the call joins the first group that accepts it, or starts a new group if none does. When
/// trying to put an expression into a group, we need to be mindful about column
/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
/// sanity is that the fused version of each group will be inserted at the place where the first
/// element of the group originally was. This means that the only condition that we need to check on
/// column references when adding an expression to a group is that all column references in a group
/// should be to columns that are earlier than the first element of the group. (No need to check
/// column references in the other direction, i.e., references in other expressions that refer to
/// columns in the group.)
pub fn fuse_window_functions(
    root: &mut HirRelationExpr,
    _context: &crate::plan::lowering::Context,
) -> Result<(), RecursionLimitError> {
    impl HirScalarExpr {
        /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
        /// considers column references that target the root level.
        /// (See `visit_columns_referring_to_root_level`.)
        fn support(&self) -> Vec<usize> {
            let mut result = Vec::new();
            self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
            result
        }

        /// Changes column references in `self` by the given remapping.
        /// Panics if a referred column is not present in `idx_map`!
        fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
            self.visit_columns_referring_to_root_level_mut(&mut |c| {
                *c = idx_map[c];
            });
            self
        }
    }

    /// Those options of a window function call that are relevant for fusion.
    #[derive(PartialEq, Eq)]
    enum WindowFuncCallOptions {
        Value(ValueWindowFuncCallOptions),
        Agg(AggregateWindowFuncCallOptions),
    }
    /// The fusion-relevant options of a value window function call.
    #[derive(PartialEq, Eq)]
    struct ValueWindowFuncCallOptions {
        partition_by: Vec<HirScalarExpr>,
        outer_order_by: Vec<HirScalarExpr>,
        inner_order_by: Vec<ColumnOrder>,
        window_frame: WindowFrame,
        ignore_nulls: bool,
    }
    /// The fusion-relevant options of a window aggregation.
    #[derive(PartialEq, Eq)]
    struct AggregateWindowFuncCallOptions {
        partition_by: Vec<HirScalarExpr>,
        outer_order_by: Vec<HirScalarExpr>,
        inner_order_by: Vec<ColumnOrder>,
        window_frame: WindowFrame,
        distinct: bool,
    }

    /// Helper function to extract the above options.
    ///
    /// Panics if `call` is neither a value window function call nor a window aggregation.
    fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
        match call {
            HirScalarExpr::Windowing(WindowExpr {
                func:
                    WindowExprType::Value(ValueWindowExpr {
                        order_by: inner_order_by,
                        window_frame,
                        ignore_nulls,
                        func: _,
                        args: _,
                    }),
                partition_by,
                order_by: outer_order_by,
            }) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
                partition_by: partition_by.clone(),
                outer_order_by: outer_order_by.clone(),
                inner_order_by: inner_order_by.clone(),
                window_frame: window_frame.clone(),
                // `bool` is `Copy`, so a plain dereference suffices.
                ignore_nulls: *ignore_nulls,
            }),
            HirScalarExpr::Windowing(WindowExpr {
                func:
                    WindowExprType::Aggregate(AggregateWindowExpr {
                        aggregate_expr: AggregateExpr {
                            distinct,
                            func: _,
                            expr: _,
                        },
                        order_by: inner_order_by,
                        window_frame,
                    }),
                partition_by,
                order_by: outer_order_by,
            }) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
                partition_by: partition_by.clone(),
                outer_order_by: outer_order_by.clone(),
                inner_order_by: inner_order_by.clone(),
                window_frame: window_frame.clone(),
                // `bool` is `Copy`, so a plain dereference suffices.
                distinct: *distinct,
            }),
            _ => panic!("extract_options should only be called on value window functions or window aggregations"),
        }
    }

    /// A set of window function calls from one Map that will be replaced by a single fused call.
    struct FusionGroup {
        /// The original column index of the first element of the group. (This is an index into the
        /// Map's `scalars` plus the arity of the Map's input.)
        first_col: usize,
        /// The options of all the window function calls in the group. (Must be the same for all the
        /// calls.)
        options: WindowFuncCallOptions,
        /// The calls in the group, with their original column indexes.
        calls: Vec<(usize, HirScalarExpr)>,
    }

    impl FusionGroup {
        /// Creates a window function call that is a fused version of all the calls in the group.
        /// `new_col` is the column index where the fused call will be inserted at.
        ///
        /// Returns the fused call together with one `RecordGet` expression per original call (in
        /// the order of `self.calls`), each extracting that call's result from the record in
        /// column `new_col`.
        fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
            let fused = match self.options {
                WindowFuncCallOptions::Value(options) => {
                    // Collect the function and the argument expression of each call.
                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
                        .calls
                        .iter()
                        .map(|(_idx, call)| {
                            if let HirScalarExpr::Windowing(WindowExpr {
                                func:
                                    WindowExprType::Value(ValueWindowExpr {
                                        func,
                                        args,
                                        order_by: _,
                                        window_frame: _,
                                        ignore_nulls: _,
                                    }),
                                partition_by: _,
                                order_by: _,
                            }) = call
                            {
                                (func.clone(), (**args).clone())
                            } else {
                                panic!("unknown window function in FusionGroup")
                            }
                        })
                        .unzip();
                    // The fused call takes a single record holding all the original arguments.
                    let fused_args = HirScalarExpr::CallVariadic {
                        func: VariadicFunc::RecordCreate {
                            // These field names are not important, because this record will only be an
                            // intermediate expression, which we'll manipulate further before it ends up
                            // anywhere where a column name would be visible.
                            field_names: iter::repeat(ColumnName::from(""))
                                .take(fused_args.len())
                                .collect(),
                        },
                        exprs: fused_args,
                    };
                    HirScalarExpr::Windowing(WindowExpr {
                        func: WindowExprType::Value(ValueWindowExpr {
                            func: ValueWindowFunc::Fused(fused_funcs),
                            args: Box::new(fused_args),
                            order_by: options.inner_order_by,
                            window_frame: options.window_frame,
                            ignore_nulls: options.ignore_nulls,
                        }),
                        partition_by: options.partition_by,
                        order_by: options.outer_order_by,
                    })
                }
                WindowFuncCallOptions::Agg(options) => {
                    // Collect the aggregate function and the input expression of each call.
                    let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
                        .calls
                        .iter()
                        .map(|(_idx, call)| {
                            if let HirScalarExpr::Windowing(WindowExpr {
                                func:
                                    WindowExprType::Aggregate(AggregateWindowExpr {
                                        aggregate_expr:
                                            AggregateExpr {
                                                func,
                                                expr,
                                                distinct: _,
                                            },
                                        order_by: _,
                                        window_frame: _,
                                    }),
                                partition_by: _,
                                order_by: _,
                            }) = call
                            {
                                (func.clone(), (**expr).clone())
                            } else {
                                panic!("unknown window function in FusionGroup")
                            }
                        })
                        .unzip();
                    // As above: field names of this intermediate record are never visible.
                    let fused_args = HirScalarExpr::CallVariadic {
                        func: VariadicFunc::RecordCreate {
                            field_names: iter::repeat(ColumnName::from(""))
                                .take(fused_args.len())
                                .collect(),
                        },
                        exprs: fused_args,
                    };
                    HirScalarExpr::Windowing(WindowExpr {
                        func: WindowExprType::Aggregate(AggregateWindowExpr {
                            aggregate_expr: AggregateExpr {
                                func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
                                expr: Box::new(fused_args),
                                distinct: options.distinct,
                            },
                            order_by: options.inner_order_by,
                            window_frame: options.window_frame,
                        }),
                        partition_by: options.partition_by,
                        order_by: options.outer_order_by,
                    })
                }
            };

            // One `RecordGet` per original call, reading field `field` of the fused result.
            let decompositions = (0..self.calls.len())
                .map(|field| HirScalarExpr::CallUnary {
                    func: UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)),
                    expr: Box::new(HirScalarExpr::Column(ColumnRef {
                        level: 0,
                        column: new_col,
                    })),
                })
                .collect();

            (fused, decompositions)
        }
    }

    let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
        // Look for calls only at the root of scalar expressions. This is enough
        // because they are always there, see 72e84bb78.
        match scalar_expr {
            HirScalarExpr::Windowing(WindowExpr {
                func: WindowExprType::Value(ValueWindowExpr { func, .. }),
                ..
            }) => {
                // Exclude those calls that are already fused. (We shouldn't currently
                // encounter these, because we just do one pass, but it's better to be
                // robust against future code changes.)
                !matches!(func, ValueWindowFunc::Fused(..))
            }
            HirScalarExpr::Windowing(WindowExpr {
                func:
                    WindowExprType::Aggregate(AggregateWindowExpr {
                        aggregate_expr: AggregateExpr { func, .. },
                        ..
                    }),
                ..
            }) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
            _ => false,
        }
    };

    root.try_visit_mut_post(&mut |rel_expr| {
        match rel_expr {
            HirRelationExpr::Map { input, scalars } => {
                // There will be various variable names involving `idx` or `col`:
                // - `idx` will always be an index into `scalars` or something similar,
                // - `col` will always be a column index,
                //   which is often `arity_before_map` + an index into `scalars`.
                let arity_before_map = input.arity();
                let orig_num_scalars = scalars.len();

                // Collect all value window function calls and window aggregations with their column
                // indexes.
                let value_or_agg_window_func_calls = scalars
                    .iter()
                    .enumerate()
                    .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
                    .map(|(idx, call)| (idx + arity_before_map, call.clone()))
                    .collect_vec();
                // Exit early if obviously no chance for fusion.
                if value_or_agg_window_func_calls.len() <= 1 {
                    // Note that we are doing this only for performance. All plans should be exactly
                    // the same even if we comment out the following line.
                    return Ok(());
                }

                // Determine the fusion groups. (Each group will later be fused into one window
                // function call.)
                // Note that this has a quadratic run time with value_or_agg_window_func_calls in
                // the worst case. However, this is fine even with 1000 window function calls.
                let mut groups: Vec<FusionGroup> = Vec::new();
                for (col, call) in value_or_agg_window_func_calls {
                    let options = extract_options(&call);
                    let support = call.support();
                    // The first existing group with identical options whose insertion point lies
                    // after every column that this call refers to.
                    let to_fuse_with = groups.iter_mut().find(|group| {
                        group.options == options && support.iter().all(|c| *c < group.first_col)
                    });
                    // `call` is owned and not needed afterwards, so it is moved into the group
                    // rather than cloned.
                    if let Some(group) = to_fuse_with {
                        group.calls.push((col, call));
                    } else {
                        groups.push(FusionGroup {
                            first_col: col,
                            options,
                            calls: vec![(col, call)],
                        });
                    }
                }

                // No fusion to do on groups of 1.
                groups.retain(|g| g.calls.len() > 1);

                // The original column indexes of all calls that will be fused away.
                let removals: BTreeSet<usize> = groups
                    .iter()
                    .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
                    .collect();

                // Mutate `scalars`.
                // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
                // `groups` is already sorted by `first_col` due to the way it was constructed.)
                // We also compute a remapping of old indexes to new indexes as we go.
                let mut groups_it = groups.drain(..);
                let mut group = groups_it.next();
                let mut remap = BTreeMap::new();
                // Columns coming from the Map's input stay where they are.
                remap.extend((0..arity_before_map).map(|col| (col, col)));
                let mut new_col: usize = arity_before_map;
                let mut new_scalars = Vec::new();
                for (old_col, e) in scalars
                    .drain(..)
                    .enumerate()
                    .map(|(idx, e)| (idx + arity_before_map, e))
                {
                    if group.as_ref().is_some_and(|g| g.first_col == old_col) {
                        // The current expression will be fused away, and a fused expression will
                        // appear in its place. Additionally, some new expressions will be inserted
                        // after the fused expression, to decompose the record that is the result of
                        // the fused call.
                        assert!(removals.contains(&old_col));
                        let group_unwrapped = group.expect("checked above");
                        let calls_cols = group_unwrapped
                            .calls
                            .iter()
                            .map(|(col, _call)| *col)
                            .collect_vec();
                        let (fused, decompositions) = group_unwrapped.fuse(new_col);
                        new_scalars.push(fused.remap(&remap));
                        new_scalars.extend(decompositions); // (no remapping needed)
                        new_col += 1;
                        // Each original call of the group now lives in its decomposition column.
                        for call_old_col in calls_cols {
                            let present = remap.insert(call_old_col, new_col);
                            assert!(present.is_none());
                            new_col += 1;
                        }
                        group = groups_it.next();
                    } else if removals.contains(&old_col) {
                        // A non-first member of a group: it was already handled when the first
                        // member of its group was visited, so it is simply dropped here.
                        assert!(remap.contains_key(&old_col));
                    } else {
                        // An expression that doesn't participate in fusion: keep it, but fix up
                        // its column references.
                        new_scalars.push(e.remap(&remap));
                        let present = remap.insert(old_col, new_col);
                        assert!(present.is_none());
                        new_col += 1;
                    }
                }
                *scalars = new_scalars;
                assert_eq!(remap.len(), arity_before_map + orig_num_scalars);

                // Add a project to permute columns back to their original places.
                *rel_expr = rel_expr.take().project(
                    (0..arity_before_map)
                        .chain((0..orig_num_scalars).map(|idx| {
                            *remap
                                .get(&(idx + arity_before_map))
                                .expect("all columns should be present by now")
                        }))
                        .collect(),
                );

                assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
            }
            _ => {}
        }
        Ok(())
    })
}