mz_sql/plan/transform_hir.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::func::variadic::RecordCreate;
19use mz_expr::visit::Visit;
20use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
21use mz_ore::stack::RecursionLimitError;
22use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
23
24use crate::plan::hir::{
25 AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
26 ValueWindowExpr, ValueWindowFunc, WindowExpr,
27};
28use crate::plan::{AggregateExpr, WindowExprType};
29
30/// Rewrites predicates that contain subqueries so that the subqueries
31/// appear in their own later predicate when possible.
32///
33/// For example, this function rewrites this expression
34///
35/// ```text
36/// Filter {
37/// predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
38/// }
39/// ```
40///
41/// like so:
42///
43/// ```text
44/// Filter {
45/// predicates: [
46/// a = b AND c = d,
47/// EXISTS (<subquery>),
48/// (<subquery 2>) = e,
49/// ]
50/// }
51/// ```
52///
53/// The rewrite causes decorrelation to incorporate prior predicates into
54/// the outer relation upon which the subquery is evaluated. In the above
55/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
56/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
57/// = e`, will be further restricted to outer rows that match `A = b AND c =
58/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
59/// subquery, especially when the original conjunction contains join keys.
60pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61 fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
62 #[allow(deprecated)]
63 expr.visit_mut_fallible(0, &mut |expr, _| {
64 match expr {
65 HirRelationExpr::Map { scalars, .. } => {
66 for scalar in scalars {
67 walk_scalar(scalar)?;
68 }
69 }
70 HirRelationExpr::CallTable { exprs, .. } => {
71 for expr in exprs {
72 walk_scalar(expr)?;
73 }
74 }
75 HirRelationExpr::Filter { predicates, .. } => {
76 let mut subqueries = vec![];
77 for predicate in &mut *predicates {
78 walk_scalar(predicate)?;
79 extract_conjuncted_subqueries(predicate, &mut subqueries)?;
80 }
81 // TODO(benesch): we could be smarter about the order in which
82 // we emit subqueries. At the moment we just emit in the order
83 // we discovered them, but ideally we'd emit them in an order
84 // that accounted for their cost/selectivity. E.g., low-cost,
85 // high-selectivity subqueries should go first.
86 for subquery in subqueries {
87 predicates.push(subquery);
88 }
89 }
90 _ => (),
91 }
92 Ok(())
93 })
94 }
95
96 fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
97 expr.try_visit_direct_subqueries_mut(&mut walk_relation)
98 }
99
100 fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
101 let mut found = false;
102 expr.try_visit_direct_subqueries(|_| {
103 found = true;
104 Ok(())
105 })?;
106 Ok(found)
107 }
108
109 /// Extracts subqueries from a conjunction into `out`.
110 ///
111 /// For example, given an expression like
112 ///
113 /// ```text
114 /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
115 /// ```
116 ///
117 /// this function rewrites the expression to
118 ///
119 /// ```text
120 /// a = b AND true AND c = d AND true
121 /// ```
122 ///
123 /// and returns the expression fragments `EXISTS (<subquery 1>)` and
124 /// `(<subquery 2>) = e` in the `out` vector.
125 fn extract_conjuncted_subqueries(
126 expr: &mut HirScalarExpr,
127 out: &mut Vec<HirScalarExpr>,
128 ) -> Result<(), RecursionLimitError> {
129 match expr {
130 HirScalarExpr::CallVariadic {
131 func: VariadicFunc::And(_),
132 exprs,
133 name: _,
134 } => {
135 exprs
136 .into_iter()
137 .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
138 }
139 expr if contains_subquery(expr)? => {
140 out.push(mem::replace(expr, HirScalarExpr::literal_true()))
141 }
142 _ => (),
143 }
144 Ok(())
145 }
146
147 walk_relation(expr)
148}
149
150/// Rewrites quantified comparisons into simpler EXISTS operators.
151///
152/// Note that this transformation is only valid when the expression is
153/// used in a context where the distinction between `FALSE` and `NULL`
154/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
155/// when the inputs to the comparison are non-nullable. This function is careful
156/// to only apply the transformation when it is valid to do so.
157///
158/// ```ignore
159/// WHERE (SELECT any(<pred>) FROM <rel>)
160/// =>
161/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
162///
163/// WHERE (SELECT all(<pred>) FROM <rel>)
164/// =>
165/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
166/// ```
167///
168/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
169/// M. Elhemali, et al.
170pub fn try_simplify_quantified_comparisons(
171 expr: &mut HirRelationExpr,
172 simplify_join_on: bool,
173) -> Result<(), RecursionLimitError> {
174 fn walk_relation(
175 expr: &mut HirRelationExpr,
176 outers: &[SqlRelationType],
177 simplify_join_on: bool,
178 ) -> Result<(), RecursionLimitError> {
179 match expr {
180 HirRelationExpr::Map { scalars, input } => {
181 walk_relation(input, outers, simplify_join_on)?;
182 let mut outers = outers.to_vec();
183 outers.insert(0, input.typ(&outers, &NO_PARAMS));
184 for scalar in scalars {
185 walk_scalar(scalar, &outers, false, simplify_join_on)?;
186 let (inner, outers) = outers
187 .split_first_mut()
188 .expect("outers known to have at least one element");
189 let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
190 inner.column_types.push(scalar_type);
191 }
192 }
193 HirRelationExpr::Filter { predicates, input } => {
194 walk_relation(input, outers, simplify_join_on)?;
195 let mut outers = outers.to_vec();
196 outers.insert(0, input.typ(&outers, &NO_PARAMS));
197 for pred in predicates {
198 walk_scalar(pred, &outers, true, simplify_join_on)?;
199 }
200 }
201 HirRelationExpr::CallTable { exprs, .. } => {
202 let mut outers = outers.to_vec();
203 outers.insert(0, SqlRelationType::empty());
204 for scalar in exprs {
205 walk_scalar(scalar, &outers, false, simplify_join_on)?;
206 }
207 }
208 HirRelationExpr::Join {
209 left, right, on, ..
210 } => {
211 walk_relation(left, outers, simplify_join_on)?;
212 let left_type = left.typ(outers, &NO_PARAMS);
213 let mut outers = outers.to_vec();
214 outers.insert(0, left_type);
215 walk_relation(right, &outers, simplify_join_on)?;
216 if simplify_join_on {
217 // Build outers with the full join output type, since the
218 // ON clause can reference columns from both sides.
219 let right_type = right.typ(&outers, &NO_PARAMS);
220 let mut join_columns = outers[0].column_types.clone();
221 join_columns.extend(right_type.column_types);
222 outers[0] = SqlRelationType::new(join_columns);
223 walk_scalar(on, &outers, true, simplify_join_on)?;
224 }
225 }
226 expr => {
227 #[allow(deprecated)]
228 let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
229 walk_relation(expr, outers, simplify_join_on)
230 });
231 }
232 }
233 Ok(())
234 }
235
236 fn walk_scalar(
237 expr: &mut HirScalarExpr,
238 outers: &[SqlRelationType],
239 mut in_filter: bool,
240 simplify_join_on: bool,
241 ) -> Result<(), RecursionLimitError> {
242 expr.try_visit_mut_pre(&mut |e| {
243 match e {
244 HirScalarExpr::Exists(input, _name) => {
245 walk_relation(input, outers, simplify_join_on)?
246 }
247 HirScalarExpr::Select(input, _name) => {
248 walk_relation(input, outers, simplify_join_on)?;
249
250 // We're inside a `(SELECT ...)` subquery. Now let's see if
251 // it has the form `(SELECT <any|all>(...) FROM <input>)`.
252 // Ideally we could do this with one pattern, but Rust's pattern
253 // matching engine is not powerful enough, so we have to do this
254 // in stages; the early returns avoid brutal nesting.
255
256 let (func, expr, input) = match &mut **input {
257 HirRelationExpr::Reduce {
258 group_key,
259 aggregates,
260 input,
261 expected_group_size: _,
262 } if group_key.is_empty() && aggregates.len() == 1 => {
263 let agg = &mut aggregates[0];
264 (&agg.func, &mut agg.expr, input)
265 }
266 _ => return Ok(()),
267 };
268
269 if !in_filter && column_type(outers, input, expr).nullable {
270 // Unless we're directly inside a WHERE, this
271 // transformation is only valid if the expression involved
272 // is non-nullable.
273 return Ok(());
274 }
275
276 match func {
277 AggregateFunc::Any => {
278 // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
279 // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
280 *e = input.take().filter(vec![expr.take()]).exists();
281 }
282 AggregateFunc::All => {
283 // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
284 // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
285 //
286 // Note that negation of <expr> alone is insufficient.
287 // Consider that `WHERE <pred>` filters out rows if
288 // `<pred>` is false *or* null. To invert the test, we
289 // need `NOT <pred> OR <pred> IS NULL`.
290 let expr = expr.take();
291 let filter = expr.clone().not().or(expr.call_is_null());
292 *e = input.take().filter(vec![filter]).exists().not();
293 }
294 _ => (),
295 }
296 }
297 _ => {
298 // As soon as we see *any* scalar expression, we are no longer
299 // directly inside a filter.
300 in_filter = false;
301 }
302 }
303 Ok(())
304 })
305 }
306
307 walk_relation(expr, &[], simplify_join_on)
308}
309
/// An empty parameter type map.
///
/// These transformations are expected to run after parameters are bound, so
/// there is no need to provide any parameter type information.
///
/// It is passed as the parameter-type argument to the `typ` calls in this
/// file, and is lazily initialized on first use.
static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
315
316fn column_type(
317 outers: &[SqlRelationType],
318 inner: &HirRelationExpr,
319 expr: &HirScalarExpr,
320) -> SqlColumnType {
321 let inner_type = inner.typ(outers, &NO_PARAMS);
322 expr.typ(outers, &inner_type, &NO_PARAMS)
323}
324
325impl HirScalarExpr {
326 /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
327 /// considers column references that target the root level.
328 /// (See `visit_columns_referring_to_root_level`.)
329 fn support(&self) -> Vec<usize> {
330 let mut result = Vec::new();
331 self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
332 result
333 }
334
335 /// Changes column references in `self` by the given remapping.
336 /// Panics if a referred column is not present in `idx_map`!
337 fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
338 self.visit_columns_referring_to_root_level_mut(&mut |c| {
339 *c = idx_map[c];
340 });
341 self
342 }
343}
344
345/// # Aims and scope
346///
347/// The aim here is to amortize the overhead of the MIR window function pattern
348/// (see `window_func_applied_to`) by fusing groups of window function calls such
349/// that each group can be performed by one instance of the window function MIR
350/// pattern.
351///
352/// For now, we fuse only value window function calls and window aggregations.
353/// (We probably won't need to fuse scalar window functions for a long time.)
354///
355/// For now, we can fuse value window function calls and window aggregations where the
356/// A. partition by
357/// B. order by
358/// C. window frame
359/// D. ignore nulls for value window functions and distinct for window aggregations
360/// are all the same. (See `extract_options`.)
361/// (Later, we could improve this to only need A. to be the same. This would require
362/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
363/// TODO: As a much simpler intermediate step, at least we should ignore options that
364/// don't matter. For example, we should be able to fuse a `lag` that has a default
365/// frame with a `first_value` that has some custom frame, because `lag` is not
366/// affected by the frame.)
367/// Note that we fuse value window function calls and window aggregations separately.
368///
369/// # Implementation
370///
371/// At a high level, what we are going to do is look for Maps with more than one window function
372/// calls, and for each Map
373/// - remove some groups of window function call expressions from the Map's `scalars`;
374/// - insert a fused version of each group;
375/// - insert some expressions that decompose the results of the fused calls;
376/// - update some column references in `scalars`: those that refer to window function results that
377/// participated in fusion, as well as those that refer to columns that moved around due to
378/// removing and inserting expressions.
379/// - insert a Project above the matched Map to permute columns back to their original places.
380///
381/// It would be tempting to find groups simply by taking a list of all window function calls
382/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
383/// but a complication is that the possible groups that we could theoretically fuse overlap.
384/// This is because when forming groups we need to also take into account column references
385/// that point inside the same Map. For example, imagine a Map with the following scalar
386/// expressions:
387/// C1, E1, C2, C3, where
388/// - E1 refers to C1
389/// - C3 refers to E1.
390/// In this situation, we could either
391/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
392/// to it);
393/// - or fuse C2 and C3.
394/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
395/// no appropriate place for the fused expression: it would have to be both before and after E1.
396///
397/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
398/// we go through `scalars`, try to put each expression into each of our groups, and the first of
399/// these succeed. When trying to put an expression into a group, we need to be mindful about column
400/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
401/// sanity is that the fused version of each group will be inserted at the place where the first
402/// element of the group originally was. This means that the only condition that we need to check on
403/// column references when adding an expression to a group is that all column references in a group
404/// should be to columns that are earlier than the first element of the group. (No need to check
405/// column references in the other direction, i.e., references in other expressions that refer to
406/// columns in the group.)
407pub fn fuse_window_functions(
408 root: &mut HirRelationExpr,
409 _context: &crate::plan::lowering::Context,
410) -> Result<(), RecursionLimitError> {
411 /// Those options of a window function call that are relevant for fusion.
412 #[derive(PartialEq, Eq)]
413 enum WindowFuncCallOptions {
414 Value(ValueWindowFuncCallOptions),
415 Agg(AggregateWindowFuncCallOptions),
416 }
417 #[derive(PartialEq, Eq)]
418 struct ValueWindowFuncCallOptions {
419 partition_by: Vec<HirScalarExpr>,
420 outer_order_by: Vec<HirScalarExpr>,
421 inner_order_by: Vec<ColumnOrder>,
422 window_frame: WindowFrame,
423 ignore_nulls: bool,
424 }
425 #[derive(PartialEq, Eq)]
426 struct AggregateWindowFuncCallOptions {
427 partition_by: Vec<HirScalarExpr>,
428 outer_order_by: Vec<HirScalarExpr>,
429 inner_order_by: Vec<ColumnOrder>,
430 window_frame: WindowFrame,
431 distinct: bool,
432 }
433
434 /// Helper function to extract the above options.
435 fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
436 match call {
437 HirScalarExpr::Windowing(
438 WindowExpr {
439 func:
440 WindowExprType::Value(ValueWindowExpr {
441 order_by: inner_order_by,
442 window_frame,
443 ignore_nulls,
444 func: _,
445 args: _,
446 }),
447 partition_by,
448 order_by: outer_order_by,
449 },
450 _name,
451 ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
452 partition_by: partition_by.clone(),
453 outer_order_by: outer_order_by.clone(),
454 inner_order_by: inner_order_by.clone(),
455 window_frame: window_frame.clone(),
456 ignore_nulls: ignore_nulls.clone(),
457 }),
458 HirScalarExpr::Windowing(
459 WindowExpr {
460 func:
461 WindowExprType::Aggregate(AggregateWindowExpr {
462 aggregate_expr:
463 AggregateExpr {
464 distinct,
465 func: _,
466 expr: _,
467 },
468 order_by: inner_order_by,
469 window_frame,
470 }),
471 partition_by,
472 order_by: outer_order_by,
473 },
474 _name,
475 ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
476 partition_by: partition_by.clone(),
477 outer_order_by: outer_order_by.clone(),
478 inner_order_by: inner_order_by.clone(),
479 window_frame: window_frame.clone(),
480 distinct: distinct.clone(),
481 }),
482 _ => panic!(
483 "extract_options should only be called on value window functions or window aggregations"
484 ),
485 }
486 }
487
488 struct FusionGroup {
489 /// The original column index of the first element of the group. (This is an index into the
490 /// Map's `scalars` plus the arity of the Map's input.)
491 first_col: usize,
492 /// The options of all the window function calls in the group. (Must be the same for all the
493 /// calls.)
494 options: WindowFuncCallOptions,
495 /// The calls in the group, with their original column indexes.
496 calls: Vec<(usize, HirScalarExpr)>,
497 }
498
499 impl FusionGroup {
500 /// Creates a window function call that is a fused version of all the calls in the group.
501 /// `new_col` is the column index where the fused call will be inserted at.
502 fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
503 let fused = match self.options {
504 WindowFuncCallOptions::Value(options) => {
505 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
506 .calls
507 .iter()
508 .map(|(_idx, call)| {
509 if let HirScalarExpr::Windowing(
510 WindowExpr {
511 func:
512 WindowExprType::Value(ValueWindowExpr {
513 func,
514 args,
515 order_by: _,
516 window_frame: _,
517 ignore_nulls: _,
518 }),
519 partition_by: _,
520 order_by: _,
521 },
522 _name,
523 ) = call
524 {
525 (func.clone(), (**args).clone())
526 } else {
527 panic!("unknown window function in FusionGroup")
528 }
529 })
530 .unzip();
531 let fused_args = HirScalarExpr::call_variadic(
532 RecordCreate {
533 // These field names are not important, because this record will only be an
534 // intermediate expression, which we'll manipulate further before it ends up
535 // anywhere where a column name would be visible.
536 field_names: iter::repeat(ColumnName::from(""))
537 .take(fused_args.len())
538 .collect(),
539 },
540 fused_args,
541 );
542 HirScalarExpr::windowing(WindowExpr {
543 func: WindowExprType::Value(ValueWindowExpr {
544 func: ValueWindowFunc::Fused(fused_funcs),
545 args: Box::new(fused_args),
546 order_by: options.inner_order_by,
547 window_frame: options.window_frame,
548 ignore_nulls: options.ignore_nulls,
549 }),
550 partition_by: options.partition_by,
551 order_by: options.outer_order_by,
552 })
553 }
554 WindowFuncCallOptions::Agg(options) => {
555 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
556 .calls
557 .iter()
558 .map(|(_idx, call)| {
559 if let HirScalarExpr::Windowing(
560 WindowExpr {
561 func:
562 WindowExprType::Aggregate(AggregateWindowExpr {
563 aggregate_expr:
564 AggregateExpr {
565 func,
566 expr,
567 distinct: _,
568 },
569 order_by: _,
570 window_frame: _,
571 }),
572 partition_by: _,
573 order_by: _,
574 },
575 _name,
576 ) = call
577 {
578 (func.clone(), (**expr).clone())
579 } else {
580 panic!("unknown window function in FusionGroup")
581 }
582 })
583 .unzip();
584 let fused_args = HirScalarExpr::call_variadic(
585 RecordCreate {
586 field_names: iter::repeat(ColumnName::from(""))
587 .take(fused_args.len())
588 .collect(),
589 },
590 fused_args,
591 );
592 HirScalarExpr::windowing(WindowExpr {
593 func: WindowExprType::Aggregate(AggregateWindowExpr {
594 aggregate_expr: AggregateExpr {
595 func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
596 expr: Box::new(fused_args),
597 distinct: options.distinct,
598 },
599 order_by: options.inner_order_by,
600 window_frame: options.window_frame,
601 }),
602 partition_by: options.partition_by,
603 order_by: options.outer_order_by,
604 })
605 }
606 };
607
608 let decompositions = (0..self.calls.len())
609 .map(|field| {
610 HirScalarExpr::column(new_col)
611 .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
612 })
613 .collect();
614
615 (fused, decompositions)
616 }
617 }
618
619 let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
620 // Look for calls only at the root of scalar expressions. This is enough
621 // because they are always there, see 72e84bb78.
622 match scalar_expr {
623 HirScalarExpr::Windowing(
624 WindowExpr {
625 func: WindowExprType::Value(ValueWindowExpr { func, .. }),
626 ..
627 },
628 _name,
629 ) => {
630 // Exclude those calls that are already fused. (We shouldn't currently
631 // encounter these, because we just do one pass, but it's better to be
632 // robust against future code changes.)
633 !matches!(func, ValueWindowFunc::Fused(..))
634 }
635 HirScalarExpr::Windowing(
636 WindowExpr {
637 func:
638 WindowExprType::Aggregate(AggregateWindowExpr {
639 aggregate_expr: AggregateExpr { func, .. },
640 ..
641 }),
642 ..
643 },
644 _name,
645 ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
646 _ => false,
647 }
648 };
649
650 root.try_visit_mut_post(&mut |rel_expr| {
651 match rel_expr {
652 HirRelationExpr::Map { input, scalars } => {
653 // There will be various variable names involving `idx` or `col`:
654 // - `idx` will always be an index into `scalars` or something similar,
655 // - `col` will always be a column index,
656 // which is often `arity_before_map` + an index into `scalars`.
657 let arity_before_map = input.arity();
658 let orig_num_scalars = scalars.len();
659
660 // Collect all value window function calls and window aggregations with their column
661 // indexes.
662 let value_or_agg_window_func_calls = scalars
663 .iter()
664 .enumerate()
665 .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
666 .map(|(idx, call)| (idx + arity_before_map, call.clone()))
667 .collect_vec();
668 // Exit early if obviously no chance for fusion.
669 if value_or_agg_window_func_calls.len() <= 1 {
670 // Note that we are doing this only for performance. All plans should be exactly
671 // the same even if we comment out the following line.
672 return Ok(());
673 }
674
675 // Determine the fusion groups. (Each group will later be fused into one window
676 // function call.)
677 // Note that this has a quadratic run time with value_or_agg_window_func_calls in
678 // the worst case. However, this is fine even with 1000 window function calls.
679 let mut groups: Vec<FusionGroup> = Vec::new();
680 for (col, call) in value_or_agg_window_func_calls {
681 let options = extract_options(&call);
682 let support = call.support();
683 let to_fuse_with = groups
684 .iter_mut()
685 .filter(|group| {
686 group.options == options && support.iter().all(|c| *c < group.first_col)
687 })
688 .next();
689 if let Some(group) = to_fuse_with {
690 group.calls.push((col, call.clone()));
691 } else {
692 groups.push(FusionGroup {
693 first_col: col,
694 options,
695 calls: vec![(col, call.clone())],
696 });
697 }
698 }
699
700 // No fusion to do on groups of 1.
701 groups.retain(|g| g.calls.len() > 1);
702
703 let removals: BTreeSet<usize> = groups
704 .iter()
705 .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
706 .collect();
707
708 // Mutate `scalars`.
709 // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
710 // `groups` is already sorted by `first_col` due to the way it was constructed.)
711 // We also compute a remapping of old indexes to new indexes as we go.
712 let mut groups_it = groups.drain(..).peekable();
713 let mut group = groups_it.next();
714 let mut remap = BTreeMap::new();
715 remap.extend((0..arity_before_map).map(|col| (col, col)));
716 let mut new_col: usize = arity_before_map;
717 let mut new_scalars = Vec::new();
718 for (old_col, e) in scalars
719 .drain(..)
720 .enumerate()
721 .map(|(idx, e)| (idx + arity_before_map, e))
722 {
723 if group.as_ref().is_some_and(|g| g.first_col == old_col) {
724 // The current expression will be fused away, and a fused expression will
725 // appear in its place. Additionally, some new expressions will be inserted
726 // after the fused expression, to decompose the record that is the result of
727 // the fused call.
728 assert!(removals.contains(&old_col));
729 let group_unwrapped = group.expect("checked above");
730 let calls_cols = group_unwrapped
731 .calls
732 .iter()
733 .map(|(col, _call)| *col)
734 .collect_vec();
735 let (fused, decompositions) = group_unwrapped.fuse(new_col);
736 new_scalars.push(fused.remap(&remap));
737 new_scalars.extend(decompositions); // (no remapping needed)
738 new_col += 1;
739 for call_old_col in calls_cols {
740 let present = remap.insert(call_old_col, new_col);
741 assert!(present.is_none());
742 new_col += 1;
743 }
744 group = groups_it.next();
745 } else if removals.contains(&old_col) {
746 assert!(remap.contains_key(&old_col));
747 } else {
748 new_scalars.push(e.remap(&remap));
749 let present = remap.insert(old_col, new_col);
750 assert!(present.is_none());
751 new_col += 1;
752 }
753 }
754 *scalars = new_scalars;
755 assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
756
757 // Add a project to permute columns back to their original places.
758 *rel_expr = rel_expr.take().project(
759 (0..arity_before_map)
760 .chain((0..orig_num_scalars).map(|idx| {
761 *remap
762 .get(&(idx + arity_before_map))
763 .expect("all columns should be present by now")
764 }))
765 .collect(),
766 );
767
768 assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
769 }
770 _ => {}
771 }
772 Ok(())
773 })
774}