mz_sql/plan/transform_hir.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::func::variadic::RecordCreate;
19use mz_expr::visit::Visit;
20use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
21use mz_ore::stack::RecursionLimitError;
22use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
23
24use crate::plan::hir::{
25 AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
26 ValueWindowExpr, ValueWindowFunc, WindowExpr,
27};
28use crate::plan::{AggregateExpr, WindowExprType};
29
30/// Rewrites predicates that contain subqueries so that the subqueries
31/// appear in their own later predicate when possible.
32///
33/// For example, this function rewrites this expression
34///
35/// ```text
36/// Filter {
37/// predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
38/// }
39/// ```
40///
41/// like so:
42///
43/// ```text
44/// Filter {
45/// predicates: [
46/// a = b AND c = d,
47/// EXISTS (<subquery>),
48/// (<subquery 2>) = e,
49/// ]
50/// }
51/// ```
52///
53/// The rewrite causes decorrelation to incorporate prior predicates into
54/// the outer relation upon which the subquery is evaluated. In the above
55/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
56/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
57/// = e`, will be further restricted to outer rows that match `A = b AND c =
58/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
59/// subquery, especially when the original conjunction contains join keys.
60pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61 fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
62 #[allow(deprecated)]
63 expr.visit_mut_fallible(0, &mut |expr, _| {
64 match expr {
65 HirRelationExpr::Map { scalars, .. } => {
66 for scalar in scalars {
67 walk_scalar(scalar)?;
68 }
69 }
70 HirRelationExpr::CallTable { exprs, .. } => {
71 for expr in exprs {
72 walk_scalar(expr)?;
73 }
74 }
75 HirRelationExpr::Filter { predicates, .. } => {
76 let mut subqueries = vec![];
77 for predicate in &mut *predicates {
78 walk_scalar(predicate)?;
79 extract_conjuncted_subqueries(predicate, &mut subqueries)?;
80 }
81 // TODO(benesch): we could be smarter about the order in which
82 // we emit subqueries. At the moment we just emit in the order
83 // we discovered them, but ideally we'd emit them in an order
84 // that accounted for their cost/selectivity. E.g., low-cost,
85 // high-selectivity subqueries should go first.
86 for subquery in subqueries {
87 predicates.push(subquery);
88 }
89 }
90 _ => (),
91 }
92 Ok(())
93 })
94 }
95
96 fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
97 expr.try_visit_mut_post(&mut |expr| {
98 match expr {
99 HirScalarExpr::Exists(input, _name) | HirScalarExpr::Select(input, _name) => {
100 walk_relation(input)?
101 }
102 _ => (),
103 }
104 Ok(())
105 })
106 }
107
108 fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
109 let mut found = false;
110 expr.visit_pre(&mut |expr| match expr {
111 HirScalarExpr::Exists(..) | HirScalarExpr::Select(..) => found = true,
112 _ => (),
113 })?;
114 Ok(found)
115 }
116
117 /// Extracts subqueries from a conjunction into `out`.
118 ///
119 /// For example, given an expression like
120 ///
121 /// ```text
122 /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
123 /// ```
124 ///
125 /// this function rewrites the expression to
126 ///
127 /// ```text
128 /// a = b AND true AND c = d AND true
129 /// ```
130 ///
131 /// and returns the expression fragments `EXISTS (<subquery 1>)` and
132 /// `(<subquery 2>) = e` in the `out` vector.
133 fn extract_conjuncted_subqueries(
134 expr: &mut HirScalarExpr,
135 out: &mut Vec<HirScalarExpr>,
136 ) -> Result<(), RecursionLimitError> {
137 match expr {
138 HirScalarExpr::CallVariadic {
139 func: VariadicFunc::And(_),
140 exprs,
141 name: _,
142 } => {
143 exprs
144 .into_iter()
145 .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
146 }
147 expr if contains_subquery(expr)? => {
148 out.push(mem::replace(expr, HirScalarExpr::literal_true()))
149 }
150 _ => (),
151 }
152 Ok(())
153 }
154
155 walk_relation(expr)
156}
157
158/// Rewrites quantified comparisons into simpler EXISTS operators.
159///
160/// Note that this transformation is only valid when the expression is
161/// used in a context where the distinction between `FALSE` and `NULL`
162/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
163/// when the inputs to the comparison are non-nullable. This function is careful
164/// to only apply the transformation when it is valid to do so.
165///
166/// ```ignore
167/// WHERE (SELECT any(<pred>) FROM <rel>)
168/// =>
169/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
170///
171/// WHERE (SELECT all(<pred>) FROM <rel>)
172/// =>
173/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
174/// ```
175///
176/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
177/// M. Elhemali, et al.
178pub fn try_simplify_quantified_comparisons(
179 expr: &mut HirRelationExpr,
180 simplify_join_on: bool,
181) -> Result<(), RecursionLimitError> {
182 fn walk_relation(
183 expr: &mut HirRelationExpr,
184 outers: &[SqlRelationType],
185 simplify_join_on: bool,
186 ) -> Result<(), RecursionLimitError> {
187 match expr {
188 HirRelationExpr::Map { scalars, input } => {
189 walk_relation(input, outers, simplify_join_on)?;
190 let mut outers = outers.to_vec();
191 outers.insert(0, input.typ(&outers, &NO_PARAMS));
192 for scalar in scalars {
193 walk_scalar(scalar, &outers, false, simplify_join_on)?;
194 let (inner, outers) = outers
195 .split_first_mut()
196 .expect("outers known to have at least one element");
197 let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
198 inner.column_types.push(scalar_type);
199 }
200 }
201 HirRelationExpr::Filter { predicates, input } => {
202 walk_relation(input, outers, simplify_join_on)?;
203 let mut outers = outers.to_vec();
204 outers.insert(0, input.typ(&outers, &NO_PARAMS));
205 for pred in predicates {
206 walk_scalar(pred, &outers, true, simplify_join_on)?;
207 }
208 }
209 HirRelationExpr::CallTable { exprs, .. } => {
210 let mut outers = outers.to_vec();
211 outers.insert(0, SqlRelationType::empty());
212 for scalar in exprs {
213 walk_scalar(scalar, &outers, false, simplify_join_on)?;
214 }
215 }
216 HirRelationExpr::Join {
217 left, right, on, ..
218 } => {
219 walk_relation(left, outers, simplify_join_on)?;
220 let left_type = left.typ(outers, &NO_PARAMS);
221 let mut outers = outers.to_vec();
222 outers.insert(0, left_type);
223 walk_relation(right, &outers, simplify_join_on)?;
224 if simplify_join_on {
225 // Build outers with the full join output type, since the
226 // ON clause can reference columns from both sides.
227 let right_type = right.typ(&outers, &NO_PARAMS);
228 let mut join_columns = outers[0].column_types.clone();
229 join_columns.extend(right_type.column_types);
230 outers[0] = SqlRelationType::new(join_columns);
231 walk_scalar(on, &outers, true, simplify_join_on)?;
232 }
233 }
234 expr => {
235 #[allow(deprecated)]
236 let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
237 walk_relation(expr, outers, simplify_join_on)
238 });
239 }
240 }
241 Ok(())
242 }
243
244 fn walk_scalar(
245 expr: &mut HirScalarExpr,
246 outers: &[SqlRelationType],
247 mut in_filter: bool,
248 simplify_join_on: bool,
249 ) -> Result<(), RecursionLimitError> {
250 expr.try_visit_mut_pre(&mut |e| {
251 match e {
252 HirScalarExpr::Exists(input, _name) => {
253 walk_relation(input, outers, simplify_join_on)?
254 }
255 HirScalarExpr::Select(input, _name) => {
256 walk_relation(input, outers, simplify_join_on)?;
257
258 // We're inside a `(SELECT ...)` subquery. Now let's see if
259 // it has the form `(SELECT <any|all>(...) FROM <input>)`.
260 // Ideally we could do this with one pattern, but Rust's pattern
261 // matching engine is not powerful enough, so we have to do this
262 // in stages; the early returns avoid brutal nesting.
263
264 let (func, expr, input) = match &mut **input {
265 HirRelationExpr::Reduce {
266 group_key,
267 aggregates,
268 input,
269 expected_group_size: _,
270 } if group_key.is_empty() && aggregates.len() == 1 => {
271 let agg = &mut aggregates[0];
272 (&agg.func, &mut agg.expr, input)
273 }
274 _ => return Ok(()),
275 };
276
277 if !in_filter && column_type(outers, input, expr).nullable {
278 // Unless we're directly inside a WHERE, this
279 // transformation is only valid if the expression involved
280 // is non-nullable.
281 return Ok(());
282 }
283
284 match func {
285 AggregateFunc::Any => {
286 // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
287 // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
288 *e = input.take().filter(vec![expr.take()]).exists();
289 }
290 AggregateFunc::All => {
291 // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
292 // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
293 //
294 // Note that negation of <expr> alone is insufficient.
295 // Consider that `WHERE <pred>` filters out rows if
296 // `<pred>` is false *or* null. To invert the test, we
297 // need `NOT <pred> OR <pred> IS NULL`.
298 let expr = expr.take();
299 let filter = expr.clone().not().or(expr.call_is_null());
300 *e = input.take().filter(vec![filter]).exists().not();
301 }
302 _ => (),
303 }
304 }
305 _ => {
306 // As soon as we see *any* scalar expression, we are no longer
307 // directly inside a filter.
308 in_filter = false;
309 }
310 }
311 Ok(())
312 })
313 }
314
315 walk_relation(expr, &[], simplify_join_on)
316}
317
318/// An empty parameter type map.
319///
320/// These transformations are expected to run after parameters are bound, so
321/// there is no need to provide any parameter type information.
322static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
323
324fn column_type(
325 outers: &[SqlRelationType],
326 inner: &HirRelationExpr,
327 expr: &HirScalarExpr,
328) -> SqlColumnType {
329 let inner_type = inner.typ(outers, &NO_PARAMS);
330 expr.typ(outers, &inner_type, &NO_PARAMS)
331}
332
333impl HirScalarExpr {
334 /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
335 /// considers column references that target the root level.
336 /// (See `visit_columns_referring_to_root_level`.)
337 fn support(&self) -> Vec<usize> {
338 let mut result = Vec::new();
339 self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
340 result
341 }
342
343 /// Changes column references in `self` by the given remapping.
344 /// Panics if a referred column is not present in `idx_map`!
345 fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
346 self.visit_columns_referring_to_root_level_mut(&mut |c| {
347 *c = idx_map[c];
348 });
349 self
350 }
351}
352
353/// # Aims and scope
354///
355/// The aim here is to amortize the overhead of the MIR window function pattern
356/// (see `window_func_applied_to`) by fusing groups of window function calls such
357/// that each group can be performed by one instance of the window function MIR
358/// pattern.
359///
360/// For now, we fuse only value window function calls and window aggregations.
361/// (We probably won't need to fuse scalar window functions for a long time.)
362///
363/// For now, we can fuse value window function calls and window aggregations where the
364/// A. partition by
365/// B. order by
366/// C. window frame
367/// D. ignore nulls for value window functions and distinct for window aggregations
368/// are all the same. (See `extract_options`.)
369/// (Later, we could improve this to only need A. to be the same. This would require
370/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
371/// TODO: As a much simpler intermediate step, at least we should ignore options that
372/// don't matter. For example, we should be able to fuse a `lag` that has a default
373/// frame with a `first_value` that has some custom frame, because `lag` is not
374/// affected by the frame.)
375/// Note that we fuse value window function calls and window aggregations separately.
376///
377/// # Implementation
378///
379/// At a high level, what we are going to do is look for Maps with more than one window function
380/// calls, and for each Map
381/// - remove some groups of window function call expressions from the Map's `scalars`;
382/// - insert a fused version of each group;
383/// - insert some expressions that decompose the results of the fused calls;
384/// - update some column references in `scalars`: those that refer to window function results that
385/// participated in fusion, as well as those that refer to columns that moved around due to
386/// removing and inserting expressions.
387/// - insert a Project above the matched Map to permute columns back to their original places.
388///
389/// It would be tempting to find groups simply by taking a list of all window function calls
390/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
391/// but a complication is that the possible groups that we could theoretically fuse overlap.
392/// This is because when forming groups we need to also take into account column references
393/// that point inside the same Map. For example, imagine a Map with the following scalar
394/// expressions:
395/// C1, E1, C2, C3, where
396/// - E1 refers to C1
397/// - C3 refers to E1.
398/// In this situation, we could either
399/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
400/// to it);
401/// - or fuse C2 and C3.
402/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
403/// no appropriate place for the fused expression: it would have to be both before and after E1.
404///
405/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
406/// we go through `scalars`, try to put each expression into each of our groups, and the first of
407/// these succeed. When trying to put an expression into a group, we need to be mindful about column
408/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
409/// sanity is that the fused version of each group will be inserted at the place where the first
410/// element of the group originally was. This means that the only condition that we need to check on
411/// column references when adding an expression to a group is that all column references in a group
412/// should be to columns that are earlier than the first element of the group. (No need to check
413/// column references in the other direction, i.e., references in other expressions that refer to
414/// columns in the group.)
415pub fn fuse_window_functions(
416 root: &mut HirRelationExpr,
417 _context: &crate::plan::lowering::Context,
418) -> Result<(), RecursionLimitError> {
419 /// Those options of a window function call that are relevant for fusion.
420 #[derive(PartialEq, Eq)]
421 enum WindowFuncCallOptions {
422 Value(ValueWindowFuncCallOptions),
423 Agg(AggregateWindowFuncCallOptions),
424 }
425 #[derive(PartialEq, Eq)]
426 struct ValueWindowFuncCallOptions {
427 partition_by: Vec<HirScalarExpr>,
428 outer_order_by: Vec<HirScalarExpr>,
429 inner_order_by: Vec<ColumnOrder>,
430 window_frame: WindowFrame,
431 ignore_nulls: bool,
432 }
433 #[derive(PartialEq, Eq)]
434 struct AggregateWindowFuncCallOptions {
435 partition_by: Vec<HirScalarExpr>,
436 outer_order_by: Vec<HirScalarExpr>,
437 inner_order_by: Vec<ColumnOrder>,
438 window_frame: WindowFrame,
439 distinct: bool,
440 }
441
442 /// Helper function to extract the above options.
443 fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
444 match call {
445 HirScalarExpr::Windowing(
446 WindowExpr {
447 func:
448 WindowExprType::Value(ValueWindowExpr {
449 order_by: inner_order_by,
450 window_frame,
451 ignore_nulls,
452 func: _,
453 args: _,
454 }),
455 partition_by,
456 order_by: outer_order_by,
457 },
458 _name,
459 ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
460 partition_by: partition_by.clone(),
461 outer_order_by: outer_order_by.clone(),
462 inner_order_by: inner_order_by.clone(),
463 window_frame: window_frame.clone(),
464 ignore_nulls: ignore_nulls.clone(),
465 }),
466 HirScalarExpr::Windowing(
467 WindowExpr {
468 func:
469 WindowExprType::Aggregate(AggregateWindowExpr {
470 aggregate_expr:
471 AggregateExpr {
472 distinct,
473 func: _,
474 expr: _,
475 },
476 order_by: inner_order_by,
477 window_frame,
478 }),
479 partition_by,
480 order_by: outer_order_by,
481 },
482 _name,
483 ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
484 partition_by: partition_by.clone(),
485 outer_order_by: outer_order_by.clone(),
486 inner_order_by: inner_order_by.clone(),
487 window_frame: window_frame.clone(),
488 distinct: distinct.clone(),
489 }),
490 _ => panic!(
491 "extract_options should only be called on value window functions or window aggregations"
492 ),
493 }
494 }
495
496 struct FusionGroup {
497 /// The original column index of the first element of the group. (This is an index into the
498 /// Map's `scalars` plus the arity of the Map's input.)
499 first_col: usize,
500 /// The options of all the window function calls in the group. (Must be the same for all the
501 /// calls.)
502 options: WindowFuncCallOptions,
503 /// The calls in the group, with their original column indexes.
504 calls: Vec<(usize, HirScalarExpr)>,
505 }
506
507 impl FusionGroup {
508 /// Creates a window function call that is a fused version of all the calls in the group.
509 /// `new_col` is the column index where the fused call will be inserted at.
510 fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
511 let fused = match self.options {
512 WindowFuncCallOptions::Value(options) => {
513 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
514 .calls
515 .iter()
516 .map(|(_idx, call)| {
517 if let HirScalarExpr::Windowing(
518 WindowExpr {
519 func:
520 WindowExprType::Value(ValueWindowExpr {
521 func,
522 args,
523 order_by: _,
524 window_frame: _,
525 ignore_nulls: _,
526 }),
527 partition_by: _,
528 order_by: _,
529 },
530 _name,
531 ) = call
532 {
533 (func.clone(), (**args).clone())
534 } else {
535 panic!("unknown window function in FusionGroup")
536 }
537 })
538 .unzip();
539 let fused_args = HirScalarExpr::call_variadic(
540 RecordCreate {
541 // These field names are not important, because this record will only be an
542 // intermediate expression, which we'll manipulate further before it ends up
543 // anywhere where a column name would be visible.
544 field_names: iter::repeat(ColumnName::from(""))
545 .take(fused_args.len())
546 .collect(),
547 },
548 fused_args,
549 );
550 HirScalarExpr::windowing(WindowExpr {
551 func: WindowExprType::Value(ValueWindowExpr {
552 func: ValueWindowFunc::Fused(fused_funcs),
553 args: Box::new(fused_args),
554 order_by: options.inner_order_by,
555 window_frame: options.window_frame,
556 ignore_nulls: options.ignore_nulls,
557 }),
558 partition_by: options.partition_by,
559 order_by: options.outer_order_by,
560 })
561 }
562 WindowFuncCallOptions::Agg(options) => {
563 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
564 .calls
565 .iter()
566 .map(|(_idx, call)| {
567 if let HirScalarExpr::Windowing(
568 WindowExpr {
569 func:
570 WindowExprType::Aggregate(AggregateWindowExpr {
571 aggregate_expr:
572 AggregateExpr {
573 func,
574 expr,
575 distinct: _,
576 },
577 order_by: _,
578 window_frame: _,
579 }),
580 partition_by: _,
581 order_by: _,
582 },
583 _name,
584 ) = call
585 {
586 (func.clone(), (**expr).clone())
587 } else {
588 panic!("unknown window function in FusionGroup")
589 }
590 })
591 .unzip();
592 let fused_args = HirScalarExpr::call_variadic(
593 RecordCreate {
594 field_names: iter::repeat(ColumnName::from(""))
595 .take(fused_args.len())
596 .collect(),
597 },
598 fused_args,
599 );
600 HirScalarExpr::windowing(WindowExpr {
601 func: WindowExprType::Aggregate(AggregateWindowExpr {
602 aggregate_expr: AggregateExpr {
603 func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
604 expr: Box::new(fused_args),
605 distinct: options.distinct,
606 },
607 order_by: options.inner_order_by,
608 window_frame: options.window_frame,
609 }),
610 partition_by: options.partition_by,
611 order_by: options.outer_order_by,
612 })
613 }
614 };
615
616 let decompositions = (0..self.calls.len())
617 .map(|field| {
618 HirScalarExpr::column(new_col)
619 .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
620 })
621 .collect();
622
623 (fused, decompositions)
624 }
625 }
626
627 let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
628 // Look for calls only at the root of scalar expressions. This is enough
629 // because they are always there, see 72e84bb78.
630 match scalar_expr {
631 HirScalarExpr::Windowing(
632 WindowExpr {
633 func: WindowExprType::Value(ValueWindowExpr { func, .. }),
634 ..
635 },
636 _name,
637 ) => {
638 // Exclude those calls that are already fused. (We shouldn't currently
639 // encounter these, because we just do one pass, but it's better to be
640 // robust against future code changes.)
641 !matches!(func, ValueWindowFunc::Fused(..))
642 }
643 HirScalarExpr::Windowing(
644 WindowExpr {
645 func:
646 WindowExprType::Aggregate(AggregateWindowExpr {
647 aggregate_expr: AggregateExpr { func, .. },
648 ..
649 }),
650 ..
651 },
652 _name,
653 ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
654 _ => false,
655 }
656 };
657
658 root.try_visit_mut_post(&mut |rel_expr| {
659 match rel_expr {
660 HirRelationExpr::Map { input, scalars } => {
661 // There will be various variable names involving `idx` or `col`:
662 // - `idx` will always be an index into `scalars` or something similar,
663 // - `col` will always be a column index,
664 // which is often `arity_before_map` + an index into `scalars`.
665 let arity_before_map = input.arity();
666 let orig_num_scalars = scalars.len();
667
668 // Collect all value window function calls and window aggregations with their column
669 // indexes.
670 let value_or_agg_window_func_calls = scalars
671 .iter()
672 .enumerate()
673 .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
674 .map(|(idx, call)| (idx + arity_before_map, call.clone()))
675 .collect_vec();
676 // Exit early if obviously no chance for fusion.
677 if value_or_agg_window_func_calls.len() <= 1 {
678 // Note that we are doing this only for performance. All plans should be exactly
679 // the same even if we comment out the following line.
680 return Ok(());
681 }
682
683 // Determine the fusion groups. (Each group will later be fused into one window
684 // function call.)
685 // Note that this has a quadratic run time with value_or_agg_window_func_calls in
686 // the worst case. However, this is fine even with 1000 window function calls.
687 let mut groups: Vec<FusionGroup> = Vec::new();
688 for (col, call) in value_or_agg_window_func_calls {
689 let options = extract_options(&call);
690 let support = call.support();
691 let to_fuse_with = groups
692 .iter_mut()
693 .filter(|group| {
694 group.options == options && support.iter().all(|c| *c < group.first_col)
695 })
696 .next();
697 if let Some(group) = to_fuse_with {
698 group.calls.push((col, call.clone()));
699 } else {
700 groups.push(FusionGroup {
701 first_col: col,
702 options,
703 calls: vec![(col, call.clone())],
704 });
705 }
706 }
707
708 // No fusion to do on groups of 1.
709 groups.retain(|g| g.calls.len() > 1);
710
711 let removals: BTreeSet<usize> = groups
712 .iter()
713 .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
714 .collect();
715
716 // Mutate `scalars`.
717 // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
718 // `groups` is already sorted by `first_col` due to the way it was constructed.)
719 // We also compute a remapping of old indexes to new indexes as we go.
720 let mut groups_it = groups.drain(..).peekable();
721 let mut group = groups_it.next();
722 let mut remap = BTreeMap::new();
723 remap.extend((0..arity_before_map).map(|col| (col, col)));
724 let mut new_col: usize = arity_before_map;
725 let mut new_scalars = Vec::new();
726 for (old_col, e) in scalars
727 .drain(..)
728 .enumerate()
729 .map(|(idx, e)| (idx + arity_before_map, e))
730 {
731 if group.as_ref().is_some_and(|g| g.first_col == old_col) {
732 // The current expression will be fused away, and a fused expression will
733 // appear in its place. Additionally, some new expressions will be inserted
734 // after the fused expression, to decompose the record that is the result of
735 // the fused call.
736 assert!(removals.contains(&old_col));
737 let group_unwrapped = group.expect("checked above");
738 let calls_cols = group_unwrapped
739 .calls
740 .iter()
741 .map(|(col, _call)| *col)
742 .collect_vec();
743 let (fused, decompositions) = group_unwrapped.fuse(new_col);
744 new_scalars.push(fused.remap(&remap));
745 new_scalars.extend(decompositions); // (no remapping needed)
746 new_col += 1;
747 for call_old_col in calls_cols {
748 let present = remap.insert(call_old_col, new_col);
749 assert!(present.is_none());
750 new_col += 1;
751 }
752 group = groups_it.next();
753 } else if removals.contains(&old_col) {
754 assert!(remap.contains_key(&old_col));
755 } else {
756 new_scalars.push(e.remap(&remap));
757 let present = remap.insert(old_col, new_col);
758 assert!(present.is_none());
759 new_col += 1;
760 }
761 }
762 *scalars = new_scalars;
763 assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
764
765 // Add a project to permute columns back to their original places.
766 *rel_expr = rel_expr.take().project(
767 (0..arity_before_map)
768 .chain((0..orig_num_scalars).map(|idx| {
769 *remap
770 .get(&(idx + arity_before_map))
771 .expect("all columns should be present by now")
772 }))
773 .collect(),
774 );
775
776 assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
777 }
778 _ => {}
779 }
780 Ok(())
781 })
782}