mz_sql/plan/transform_hir.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::visit::Visit;
19use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
20use mz_ore::stack::RecursionLimitError;
21use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
22
23use crate::plan::hir::{
24 AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
25 ValueWindowExpr, ValueWindowFunc, WindowExpr,
26};
27use crate::plan::{AggregateExpr, WindowExprType};
28
29/// Rewrites predicates that contain subqueries so that the subqueries
30/// appear in their own later predicate when possible.
31///
32/// For example, this function rewrites this expression
33///
34/// ```text
35/// Filter {
36/// predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
37/// }
38/// ```
39///
40/// like so:
41///
42/// ```text
43/// Filter {
44/// predicates: [
45/// a = b AND c = d,
46/// EXISTS (<subquery>),
47/// (<subquery 2>) = e,
48/// ]
49/// }
50/// ```
51///
52/// The rewrite causes decorrelation to incorporate prior predicates into
53/// the outer relation upon which the subquery is evaluated. In the above
54/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
55/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
56/// = e`, will be further restricted to outer rows that match `A = b AND c =
57/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
58/// subquery, especially when the original conjunction contains join keys.
59pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
60 fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61 #[allow(deprecated)]
62 expr.visit_mut_fallible(0, &mut |expr, _| {
63 match expr {
64 HirRelationExpr::Map { scalars, .. } => {
65 for scalar in scalars {
66 walk_scalar(scalar)?;
67 }
68 }
69 HirRelationExpr::CallTable { exprs, .. } => {
70 for expr in exprs {
71 walk_scalar(expr)?;
72 }
73 }
74 HirRelationExpr::Filter { predicates, .. } => {
75 let mut subqueries = vec![];
76 for predicate in &mut *predicates {
77 walk_scalar(predicate)?;
78 extract_conjuncted_subqueries(predicate, &mut subqueries)?;
79 }
80 // TODO(benesch): we could be smarter about the order in which
81 // we emit subqueries. At the moment we just emit in the order
82 // we discovered them, but ideally we'd emit them in an order
83 // that accounted for their cost/selectivity. E.g., low-cost,
84 // high-selectivity subqueries should go first.
85 for subquery in subqueries {
86 predicates.push(subquery);
87 }
88 }
89 _ => (),
90 }
91 Ok(())
92 })
93 }
94
95 fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
96 expr.try_visit_mut_post(&mut |expr| {
97 match expr {
98 HirScalarExpr::Exists(input, _name) | HirScalarExpr::Select(input, _name) => {
99 walk_relation(input)?
100 }
101 _ => (),
102 }
103 Ok(())
104 })
105 }
106
107 fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
108 let mut found = false;
109 expr.visit_pre(&mut |expr| match expr {
110 HirScalarExpr::Exists(..) | HirScalarExpr::Select(..) => found = true,
111 _ => (),
112 })?;
113 Ok(found)
114 }
115
116 /// Extracts subqueries from a conjunction into `out`.
117 ///
118 /// For example, given an expression like
119 ///
120 /// ```text
121 /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
122 /// ```
123 ///
124 /// this function rewrites the expression to
125 ///
126 /// ```text
127 /// a = b AND true AND c = d AND true
128 /// ```
129 ///
130 /// and returns the expression fragments `EXISTS (<subquery 1>)` and
131 /// `(<subquery 2>) = e` in the `out` vector.
132 fn extract_conjuncted_subqueries(
133 expr: &mut HirScalarExpr,
134 out: &mut Vec<HirScalarExpr>,
135 ) -> Result<(), RecursionLimitError> {
136 match expr {
137 HirScalarExpr::CallVariadic {
138 func: VariadicFunc::And,
139 exprs,
140 name: _,
141 } => {
142 exprs
143 .into_iter()
144 .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
145 }
146 expr if contains_subquery(expr)? => {
147 out.push(mem::replace(expr, HirScalarExpr::literal_true()))
148 }
149 _ => (),
150 }
151 Ok(())
152 }
153
154 walk_relation(expr)
155}
156
157/// Rewrites quantified comparisons into simpler EXISTS operators.
158///
159/// Note that this transformation is only valid when the expression is
160/// used in a context where the distinction between `FALSE` and `NULL`
161/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
162/// when the inputs to the comparison are non-nullable. This function is careful
163/// to only apply the transformation when it is valid to do so.
164///
165/// ```ignore
166/// WHERE (SELECT any(<pred>) FROM <rel>)
167/// =>
168/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
169///
170/// WHERE (SELECT all(<pred>) FROM <rel>)
171/// =>
172/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
173/// ```
174///
175/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
176/// M. Elhemali, et al.
177pub fn try_simplify_quantified_comparisons(
178 expr: &mut HirRelationExpr,
179) -> Result<(), RecursionLimitError> {
180 fn walk_relation(
181 expr: &mut HirRelationExpr,
182 outers: &[SqlRelationType],
183 ) -> Result<(), RecursionLimitError> {
184 match expr {
185 HirRelationExpr::Map { scalars, input } => {
186 walk_relation(input, outers)?;
187 let mut outers = outers.to_vec();
188 outers.insert(0, input.typ(&outers, &NO_PARAMS));
189 for scalar in scalars {
190 walk_scalar(scalar, &outers, false)?;
191 let (inner, outers) = outers
192 .split_first_mut()
193 .expect("outers known to have at least one element");
194 let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
195 inner.column_types.push(scalar_type);
196 }
197 }
198 HirRelationExpr::Filter { predicates, input } => {
199 walk_relation(input, outers)?;
200 let mut outers = outers.to_vec();
201 outers.insert(0, input.typ(&outers, &NO_PARAMS));
202 for pred in predicates {
203 walk_scalar(pred, &outers, true)?;
204 }
205 }
206 HirRelationExpr::CallTable { exprs, .. } => {
207 let mut outers = outers.to_vec();
208 outers.insert(0, SqlRelationType::empty());
209 for scalar in exprs {
210 walk_scalar(scalar, &outers, false)?;
211 }
212 }
213 HirRelationExpr::Join { left, right, .. } => {
214 walk_relation(left, outers)?;
215 let mut outers = outers.to_vec();
216 outers.insert(0, left.typ(&outers, &NO_PARAMS));
217 walk_relation(right, &outers)?;
218 }
219 expr => {
220 #[allow(deprecated)]
221 let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
222 walk_relation(expr, outers)
223 });
224 }
225 }
226 Ok(())
227 }
228
229 fn walk_scalar(
230 expr: &mut HirScalarExpr,
231 outers: &[SqlRelationType],
232 mut in_filter: bool,
233 ) -> Result<(), RecursionLimitError> {
234 expr.try_visit_mut_pre(&mut |e| {
235 match e {
236 HirScalarExpr::Exists(input, _name) => walk_relation(input, outers)?,
237 HirScalarExpr::Select(input, _name) => {
238 walk_relation(input, outers)?;
239
240 // We're inside a `(SELECT ...)` subquery. Now let's see if
241 // it has the form `(SELECT <any|all>(...) FROM <input>)`.
242 // Ideally we could do this with one pattern, but Rust's pattern
243 // matching engine is not powerful enough, so we have to do this
244 // in stages; the early returns avoid brutal nesting.
245
246 let (func, expr, input) = match &mut **input {
247 HirRelationExpr::Reduce {
248 group_key,
249 aggregates,
250 input,
251 expected_group_size: _,
252 } if group_key.is_empty() && aggregates.len() == 1 => {
253 let agg = &mut aggregates[0];
254 (&agg.func, &mut agg.expr, input)
255 }
256 _ => return Ok(()),
257 };
258
259 if !in_filter && column_type(outers, input, expr).nullable {
260 // Unless we're directly inside a WHERE, this
261 // transformation is only valid if the expression involved
262 // is non-nullable.
263 return Ok(());
264 }
265
266 match func {
267 AggregateFunc::Any => {
268 // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
269 // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
270 *e = input.take().filter(vec![expr.take()]).exists();
271 }
272 AggregateFunc::All => {
273 // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
274 // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
275 //
276 // Note that negation of <expr> alone is insufficient.
277 // Consider that `WHERE <pred>` filters out rows if
278 // `<pred>` is false *or* null. To invert the test, we
279 // need `NOT <pred> OR <pred> IS NULL`.
280 let expr = expr.take();
281 let filter = expr.clone().not().or(expr.call_is_null());
282 *e = input.take().filter(vec![filter]).exists().not();
283 }
284 _ => (),
285 }
286 }
287 _ => {
288 // As soon as we see *any* scalar expression, we are no longer
289 // directly inside a filter.
290 in_filter = false;
291 }
292 }
293 Ok(())
294 })
295 }
296
297 walk_relation(expr, &[])
298}
299
300/// An empty parameter type map.
301///
302/// These transformations are expected to run after parameters are bound, so
303/// there is no need to provide any parameter type information.
304static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
305
306fn column_type(
307 outers: &[SqlRelationType],
308 inner: &HirRelationExpr,
309 expr: &HirScalarExpr,
310) -> SqlColumnType {
311 let inner_type = inner.typ(outers, &NO_PARAMS);
312 expr.typ(outers, &inner_type, &NO_PARAMS)
313}
314
315impl HirScalarExpr {
316 /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
317 /// considers column references that target the root level.
318 /// (See `visit_columns_referring_to_root_level`.)
319 fn support(&self) -> Vec<usize> {
320 let mut result = Vec::new();
321 self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
322 result
323 }
324
325 /// Changes column references in `self` by the given remapping.
326 /// Panics if a referred column is not present in `idx_map`!
327 fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
328 self.visit_columns_referring_to_root_level_mut(&mut |c| {
329 *c = idx_map[c];
330 });
331 self
332 }
333}
334
335/// # Aims and scope
336///
337/// The aim here is to amortize the overhead of the MIR window function pattern
338/// (see `window_func_applied_to`) by fusing groups of window function calls such
339/// that each group can be performed by one instance of the window function MIR
340/// pattern.
341///
342/// For now, we fuse only value window function calls and window aggregations.
343/// (We probably won't need to fuse scalar window functions for a long time.)
344///
345/// For now, we can fuse value window function calls and window aggregations where the
346/// A. partition by
347/// B. order by
348/// C. window frame
349/// D. ignore nulls for value window functions and distinct for window aggregations
350/// are all the same. (See `extract_options`.)
351/// (Later, we could improve this to only need A. to be the same. This would require
352/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
353/// TODO: As a much simpler intermediate step, at least we should ignore options that
354/// don't matter. For example, we should be able to fuse a `lag` that has a default
355/// frame with a `first_value` that has some custom frame, because `lag` is not
356/// affected by the frame.)
357/// Note that we fuse value window function calls and window aggregations separately.
358///
359/// # Implementation
360///
361/// At a high level, what we are going to do is look for Maps with more than one window function
362/// calls, and for each Map
363/// - remove some groups of window function call expressions from the Map's `scalars`;
364/// - insert a fused version of each group;
365/// - insert some expressions that decompose the results of the fused calls;
366/// - update some column references in `scalars`: those that refer to window function results that
367/// participated in fusion, as well as those that refer to columns that moved around due to
368/// removing and inserting expressions.
369/// - insert a Project above the matched Map to permute columns back to their original places.
370///
371/// It would be tempting to find groups simply by taking a list of all window function calls
372/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
373/// but a complication is that the possible groups that we could theoretically fuse overlap.
374/// This is because when forming groups we need to also take into account column references
375/// that point inside the same Map. For example, imagine a Map with the following scalar
376/// expressions:
377/// C1, E1, C2, C3, where
378/// - E1 refers to C1
379/// - C3 refers to E1.
380/// In this situation, we could either
381/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
382/// to it);
383/// - or fuse C2 and C3.
384/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
385/// no appropriate place for the fused expression: it would have to be both before and after E1.
386///
387/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
388/// we go through `scalars`, try to put each expression into each of our groups, and the first of
389/// these succeed. When trying to put an expression into a group, we need to be mindful about column
390/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
391/// sanity is that the fused version of each group will be inserted at the place where the first
392/// element of the group originally was. This means that the only condition that we need to check on
393/// column references when adding an expression to a group is that all column references in a group
394/// should be to columns that are earlier than the first element of the group. (No need to check
395/// column references in the other direction, i.e., references in other expressions that refer to
396/// columns in the group.)
397pub fn fuse_window_functions(
398 root: &mut HirRelationExpr,
399 _context: &crate::plan::lowering::Context,
400) -> Result<(), RecursionLimitError> {
401 /// Those options of a window function call that are relevant for fusion.
402 #[derive(PartialEq, Eq)]
403 enum WindowFuncCallOptions {
404 Value(ValueWindowFuncCallOptions),
405 Agg(AggregateWindowFuncCallOptions),
406 }
407 #[derive(PartialEq, Eq)]
408 struct ValueWindowFuncCallOptions {
409 partition_by: Vec<HirScalarExpr>,
410 outer_order_by: Vec<HirScalarExpr>,
411 inner_order_by: Vec<ColumnOrder>,
412 window_frame: WindowFrame,
413 ignore_nulls: bool,
414 }
415 #[derive(PartialEq, Eq)]
416 struct AggregateWindowFuncCallOptions {
417 partition_by: Vec<HirScalarExpr>,
418 outer_order_by: Vec<HirScalarExpr>,
419 inner_order_by: Vec<ColumnOrder>,
420 window_frame: WindowFrame,
421 distinct: bool,
422 }
423
424 /// Helper function to extract the above options.
425 fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
426 match call {
427 HirScalarExpr::Windowing(
428 WindowExpr {
429 func:
430 WindowExprType::Value(ValueWindowExpr {
431 order_by: inner_order_by,
432 window_frame,
433 ignore_nulls,
434 func: _,
435 args: _,
436 }),
437 partition_by,
438 order_by: outer_order_by,
439 },
440 _name,
441 ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
442 partition_by: partition_by.clone(),
443 outer_order_by: outer_order_by.clone(),
444 inner_order_by: inner_order_by.clone(),
445 window_frame: window_frame.clone(),
446 ignore_nulls: ignore_nulls.clone(),
447 }),
448 HirScalarExpr::Windowing(
449 WindowExpr {
450 func:
451 WindowExprType::Aggregate(AggregateWindowExpr {
452 aggregate_expr:
453 AggregateExpr {
454 distinct,
455 func: _,
456 expr: _,
457 },
458 order_by: inner_order_by,
459 window_frame,
460 }),
461 partition_by,
462 order_by: outer_order_by,
463 },
464 _name,
465 ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
466 partition_by: partition_by.clone(),
467 outer_order_by: outer_order_by.clone(),
468 inner_order_by: inner_order_by.clone(),
469 window_frame: window_frame.clone(),
470 distinct: distinct.clone(),
471 }),
472 _ => panic!(
473 "extract_options should only be called on value window functions or window aggregations"
474 ),
475 }
476 }
477
478 struct FusionGroup {
479 /// The original column index of the first element of the group. (This is an index into the
480 /// Map's `scalars` plus the arity of the Map's input.)
481 first_col: usize,
482 /// The options of all the window function calls in the group. (Must be the same for all the
483 /// calls.)
484 options: WindowFuncCallOptions,
485 /// The calls in the group, with their original column indexes.
486 calls: Vec<(usize, HirScalarExpr)>,
487 }
488
489 impl FusionGroup {
490 /// Creates a window function call that is a fused version of all the calls in the group.
491 /// `new_col` is the column index where the fused call will be inserted at.
492 fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
493 let fused = match self.options {
494 WindowFuncCallOptions::Value(options) => {
495 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
496 .calls
497 .iter()
498 .map(|(_idx, call)| {
499 if let HirScalarExpr::Windowing(
500 WindowExpr {
501 func:
502 WindowExprType::Value(ValueWindowExpr {
503 func,
504 args,
505 order_by: _,
506 window_frame: _,
507 ignore_nulls: _,
508 }),
509 partition_by: _,
510 order_by: _,
511 },
512 _name,
513 ) = call
514 {
515 (func.clone(), (**args).clone())
516 } else {
517 panic!("unknown window function in FusionGroup")
518 }
519 })
520 .unzip();
521 let fused_args = HirScalarExpr::call_variadic(
522 VariadicFunc::RecordCreate {
523 // These field names are not important, because this record will only be an
524 // intermediate expression, which we'll manipulate further before it ends up
525 // anywhere where a column name would be visible.
526 field_names: iter::repeat(ColumnName::from(""))
527 .take(fused_args.len())
528 .collect(),
529 },
530 fused_args,
531 );
532 HirScalarExpr::windowing(WindowExpr {
533 func: WindowExprType::Value(ValueWindowExpr {
534 func: ValueWindowFunc::Fused(fused_funcs),
535 args: Box::new(fused_args),
536 order_by: options.inner_order_by,
537 window_frame: options.window_frame,
538 ignore_nulls: options.ignore_nulls,
539 }),
540 partition_by: options.partition_by,
541 order_by: options.outer_order_by,
542 })
543 }
544 WindowFuncCallOptions::Agg(options) => {
545 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
546 .calls
547 .iter()
548 .map(|(_idx, call)| {
549 if let HirScalarExpr::Windowing(
550 WindowExpr {
551 func:
552 WindowExprType::Aggregate(AggregateWindowExpr {
553 aggregate_expr:
554 AggregateExpr {
555 func,
556 expr,
557 distinct: _,
558 },
559 order_by: _,
560 window_frame: _,
561 }),
562 partition_by: _,
563 order_by: _,
564 },
565 _name,
566 ) = call
567 {
568 (func.clone(), (**expr).clone())
569 } else {
570 panic!("unknown window function in FusionGroup")
571 }
572 })
573 .unzip();
574 let fused_args = HirScalarExpr::call_variadic(
575 VariadicFunc::RecordCreate {
576 field_names: iter::repeat(ColumnName::from(""))
577 .take(fused_args.len())
578 .collect(),
579 },
580 fused_args,
581 );
582 HirScalarExpr::windowing(WindowExpr {
583 func: WindowExprType::Aggregate(AggregateWindowExpr {
584 aggregate_expr: AggregateExpr {
585 func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
586 expr: Box::new(fused_args),
587 distinct: options.distinct,
588 },
589 order_by: options.inner_order_by,
590 window_frame: options.window_frame,
591 }),
592 partition_by: options.partition_by,
593 order_by: options.outer_order_by,
594 })
595 }
596 };
597
598 let decompositions = (0..self.calls.len())
599 .map(|field| {
600 HirScalarExpr::column(new_col)
601 .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
602 })
603 .collect();
604
605 (fused, decompositions)
606 }
607 }
608
609 let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
610 // Look for calls only at the root of scalar expressions. This is enough
611 // because they are always there, see 72e84bb78.
612 match scalar_expr {
613 HirScalarExpr::Windowing(
614 WindowExpr {
615 func: WindowExprType::Value(ValueWindowExpr { func, .. }),
616 ..
617 },
618 _name,
619 ) => {
620 // Exclude those calls that are already fused. (We shouldn't currently
621 // encounter these, because we just do one pass, but it's better to be
622 // robust against future code changes.)
623 !matches!(func, ValueWindowFunc::Fused(..))
624 }
625 HirScalarExpr::Windowing(
626 WindowExpr {
627 func:
628 WindowExprType::Aggregate(AggregateWindowExpr {
629 aggregate_expr: AggregateExpr { func, .. },
630 ..
631 }),
632 ..
633 },
634 _name,
635 ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
636 _ => false,
637 }
638 };
639
640 root.try_visit_mut_post(&mut |rel_expr| {
641 match rel_expr {
642 HirRelationExpr::Map { input, scalars } => {
643 // There will be various variable names involving `idx` or `col`:
644 // - `idx` will always be an index into `scalars` or something similar,
645 // - `col` will always be a column index,
646 // which is often `arity_before_map` + an index into `scalars`.
647 let arity_before_map = input.arity();
648 let orig_num_scalars = scalars.len();
649
650 // Collect all value window function calls and window aggregations with their column
651 // indexes.
652 let value_or_agg_window_func_calls = scalars
653 .iter()
654 .enumerate()
655 .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
656 .map(|(idx, call)| (idx + arity_before_map, call.clone()))
657 .collect_vec();
658 // Exit early if obviously no chance for fusion.
659 if value_or_agg_window_func_calls.len() <= 1 {
660 // Note that we are doing this only for performance. All plans should be exactly
661 // the same even if we comment out the following line.
662 return Ok(());
663 }
664
665 // Determine the fusion groups. (Each group will later be fused into one window
666 // function call.)
667 // Note that this has a quadratic run time with value_or_agg_window_func_calls in
668 // the worst case. However, this is fine even with 1000 window function calls.
669 let mut groups: Vec<FusionGroup> = Vec::new();
670 for (col, call) in value_or_agg_window_func_calls {
671 let options = extract_options(&call);
672 let support = call.support();
673 let to_fuse_with = groups
674 .iter_mut()
675 .filter(|group| {
676 group.options == options && support.iter().all(|c| *c < group.first_col)
677 })
678 .next();
679 if let Some(group) = to_fuse_with {
680 group.calls.push((col, call.clone()));
681 } else {
682 groups.push(FusionGroup {
683 first_col: col,
684 options,
685 calls: vec![(col, call.clone())],
686 });
687 }
688 }
689
690 // No fusion to do on groups of 1.
691 groups.retain(|g| g.calls.len() > 1);
692
693 let removals: BTreeSet<usize> = groups
694 .iter()
695 .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
696 .collect();
697
698 // Mutate `scalars`.
699 // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
700 // `groups` is already sorted by `first_col` due to the way it was constructed.)
701 // We also compute a remapping of old indexes to new indexes as we go.
702 let mut groups_it = groups.drain(..).peekable();
703 let mut group = groups_it.next();
704 let mut remap = BTreeMap::new();
705 remap.extend((0..arity_before_map).map(|col| (col, col)));
706 let mut new_col: usize = arity_before_map;
707 let mut new_scalars = Vec::new();
708 for (old_col, e) in scalars
709 .drain(..)
710 .enumerate()
711 .map(|(idx, e)| (idx + arity_before_map, e))
712 {
713 if group.as_ref().is_some_and(|g| g.first_col == old_col) {
714 // The current expression will be fused away, and a fused expression will
715 // appear in its place. Additionally, some new expressions will be inserted
716 // after the fused expression, to decompose the record that is the result of
717 // the fused call.
718 assert!(removals.contains(&old_col));
719 let group_unwrapped = group.expect("checked above");
720 let calls_cols = group_unwrapped
721 .calls
722 .iter()
723 .map(|(col, _call)| *col)
724 .collect_vec();
725 let (fused, decompositions) = group_unwrapped.fuse(new_col);
726 new_scalars.push(fused.remap(&remap));
727 new_scalars.extend(decompositions); // (no remapping needed)
728 new_col += 1;
729 for call_old_col in calls_cols {
730 let present = remap.insert(call_old_col, new_col);
731 assert!(present.is_none());
732 new_col += 1;
733 }
734 group = groups_it.next();
735 } else if removals.contains(&old_col) {
736 assert!(remap.contains_key(&old_col));
737 } else {
738 new_scalars.push(e.remap(&remap));
739 let present = remap.insert(old_col, new_col);
740 assert!(present.is_none());
741 new_col += 1;
742 }
743 }
744 *scalars = new_scalars;
745 assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
746
747 // Add a project to permute columns back to their original places.
748 *rel_expr = rel_expr.take().project(
749 (0..arity_before_map)
750 .chain((0..orig_num_scalars).map(|idx| {
751 *remap
752 .get(&(idx + arity_before_map))
753 .expect("all columns should be present by now")
754 }))
755 .collect(),
756 );
757
758 assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
759 }
760 _ => {}
761 }
762 Ok(())
763 })
764}