mz_sql/plan/transform_hir.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::func::variadic::RecordCreate;
19use mz_expr::visit::Visit;
20use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
21use mz_ore::stack::RecursionLimitError;
22use mz_repr::{ColumnName, SqlColumnType, SqlRelationType, SqlScalarType};
23
24use crate::plan::hir::{
25 AbstractExpr, AggregateFunc, AggregateWindowExpr, HirRelationExpr, HirScalarExpr,
26 ValueWindowExpr, ValueWindowFunc, WindowExpr,
27};
28use crate::plan::{AggregateExpr, WindowExprType};
29
30/// Rewrites predicates that contain subqueries so that the subqueries
31/// appear in their own later predicate when possible.
32///
33/// For example, this function rewrites this expression
34///
35/// ```text
36/// Filter {
37/// predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
38/// }
39/// ```
40///
41/// like so:
42///
43/// ```text
44/// Filter {
45/// predicates: [
46/// a = b AND c = d,
47/// EXISTS (<subquery>),
48/// (<subquery 2>) = e,
49/// ]
50/// }
51/// ```
52///
53/// The rewrite causes decorrelation to incorporate prior predicates into
54/// the outer relation upon which the subquery is evaluated. In the above
55/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
56/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
57/// = e`, will be further restricted to outer rows that match `A = b AND c =
58/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
59/// subquery, especially when the original conjunction contains join keys.
60pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61 fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
62 #[allow(deprecated)]
63 expr.visit_mut_fallible(0, &mut |expr, _| {
64 match expr {
65 HirRelationExpr::Map { scalars, .. } => {
66 for scalar in scalars {
67 walk_scalar(scalar)?;
68 }
69 }
70 HirRelationExpr::CallTable { exprs, .. } => {
71 for expr in exprs {
72 walk_scalar(expr)?;
73 }
74 }
75 HirRelationExpr::Filter { predicates, .. } => {
76 let mut subqueries = vec![];
77 for predicate in &mut *predicates {
78 walk_scalar(predicate)?;
79 extract_conjuncted_subqueries(predicate, &mut subqueries)?;
80 }
81 // TODO(benesch): we could be smarter about the order in which
82 // we emit subqueries. At the moment we just emit in the order
83 // we discovered them, but ideally we'd emit them in an order
84 // that accounted for their cost/selectivity. E.g., low-cost,
85 // high-selectivity subqueries should go first.
86 for subquery in subqueries {
87 predicates.push(subquery);
88 }
89 }
90 _ => (),
91 }
92 Ok(())
93 })
94 }
95
96 fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
97 expr.try_visit_mut_post(&mut |expr| {
98 match expr {
99 HirScalarExpr::Exists(input, _name) | HirScalarExpr::Select(input, _name) => {
100 walk_relation(input)?
101 }
102 _ => (),
103 }
104 Ok(())
105 })
106 }
107
108 fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
109 let mut found = false;
110 expr.visit_pre(&mut |expr| match expr {
111 HirScalarExpr::Exists(..) | HirScalarExpr::Select(..) => found = true,
112 _ => (),
113 })?;
114 Ok(found)
115 }
116
117 /// Extracts subqueries from a conjunction into `out`.
118 ///
119 /// For example, given an expression like
120 ///
121 /// ```text
122 /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
123 /// ```
124 ///
125 /// this function rewrites the expression to
126 ///
127 /// ```text
128 /// a = b AND true AND c = d AND true
129 /// ```
130 ///
131 /// and returns the expression fragments `EXISTS (<subquery 1>)` and
132 /// `(<subquery 2>) = e` in the `out` vector.
133 fn extract_conjuncted_subqueries(
134 expr: &mut HirScalarExpr,
135 out: &mut Vec<HirScalarExpr>,
136 ) -> Result<(), RecursionLimitError> {
137 match expr {
138 HirScalarExpr::CallVariadic {
139 func: VariadicFunc::And(_),
140 exprs,
141 name: _,
142 } => {
143 exprs
144 .into_iter()
145 .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
146 }
147 expr if contains_subquery(expr)? => {
148 out.push(mem::replace(expr, HirScalarExpr::literal_true()))
149 }
150 _ => (),
151 }
152 Ok(())
153 }
154
155 walk_relation(expr)
156}
157
158/// Rewrites quantified comparisons into simpler EXISTS operators.
159///
160/// Note that this transformation is only valid when the expression is
161/// used in a context where the distinction between `FALSE` and `NULL`
162/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
163/// when the inputs to the comparison are non-nullable. This function is careful
164/// to only apply the transformation when it is valid to do so.
165///
166/// ```ignore
167/// WHERE (SELECT any(<pred>) FROM <rel>)
168/// =>
169/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
170///
171/// WHERE (SELECT all(<pred>) FROM <rel>)
172/// =>
173/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
174/// ```
175///
176/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
177/// M. Elhemali, et al.
178pub fn try_simplify_quantified_comparisons(
179 expr: &mut HirRelationExpr,
180) -> Result<(), RecursionLimitError> {
181 fn walk_relation(
182 expr: &mut HirRelationExpr,
183 outers: &[SqlRelationType],
184 ) -> Result<(), RecursionLimitError> {
185 match expr {
186 HirRelationExpr::Map { scalars, input } => {
187 walk_relation(input, outers)?;
188 let mut outers = outers.to_vec();
189 outers.insert(0, input.typ(&outers, &NO_PARAMS));
190 for scalar in scalars {
191 walk_scalar(scalar, &outers, false)?;
192 let (inner, outers) = outers
193 .split_first_mut()
194 .expect("outers known to have at least one element");
195 let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
196 inner.column_types.push(scalar_type);
197 }
198 }
199 HirRelationExpr::Filter { predicates, input } => {
200 walk_relation(input, outers)?;
201 let mut outers = outers.to_vec();
202 outers.insert(0, input.typ(&outers, &NO_PARAMS));
203 for pred in predicates {
204 walk_scalar(pred, &outers, true)?;
205 }
206 }
207 HirRelationExpr::CallTable { exprs, .. } => {
208 let mut outers = outers.to_vec();
209 outers.insert(0, SqlRelationType::empty());
210 for scalar in exprs {
211 walk_scalar(scalar, &outers, false)?;
212 }
213 }
214 HirRelationExpr::Join { left, right, .. } => {
215 walk_relation(left, outers)?;
216 let mut outers = outers.to_vec();
217 outers.insert(0, left.typ(&outers, &NO_PARAMS));
218 walk_relation(right, &outers)?;
219 }
220 expr => {
221 #[allow(deprecated)]
222 let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
223 walk_relation(expr, outers)
224 });
225 }
226 }
227 Ok(())
228 }
229
230 fn walk_scalar(
231 expr: &mut HirScalarExpr,
232 outers: &[SqlRelationType],
233 mut in_filter: bool,
234 ) -> Result<(), RecursionLimitError> {
235 expr.try_visit_mut_pre(&mut |e| {
236 match e {
237 HirScalarExpr::Exists(input, _name) => walk_relation(input, outers)?,
238 HirScalarExpr::Select(input, _name) => {
239 walk_relation(input, outers)?;
240
241 // We're inside a `(SELECT ...)` subquery. Now let's see if
242 // it has the form `(SELECT <any|all>(...) FROM <input>)`.
243 // Ideally we could do this with one pattern, but Rust's pattern
244 // matching engine is not powerful enough, so we have to do this
245 // in stages; the early returns avoid brutal nesting.
246
247 let (func, expr, input) = match &mut **input {
248 HirRelationExpr::Reduce {
249 group_key,
250 aggregates,
251 input,
252 expected_group_size: _,
253 } if group_key.is_empty() && aggregates.len() == 1 => {
254 let agg = &mut aggregates[0];
255 (&agg.func, &mut agg.expr, input)
256 }
257 _ => return Ok(()),
258 };
259
260 if !in_filter && column_type(outers, input, expr).nullable {
261 // Unless we're directly inside a WHERE, this
262 // transformation is only valid if the expression involved
263 // is non-nullable.
264 return Ok(());
265 }
266
267 match func {
268 AggregateFunc::Any => {
269 // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
270 // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
271 *e = input.take().filter(vec![expr.take()]).exists();
272 }
273 AggregateFunc::All => {
274 // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
275 // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
276 //
277 // Note that negation of <expr> alone is insufficient.
278 // Consider that `WHERE <pred>` filters out rows if
279 // `<pred>` is false *or* null. To invert the test, we
280 // need `NOT <pred> OR <pred> IS NULL`.
281 let expr = expr.take();
282 let filter = expr.clone().not().or(expr.call_is_null());
283 *e = input.take().filter(vec![filter]).exists().not();
284 }
285 _ => (),
286 }
287 }
288 _ => {
289 // As soon as we see *any* scalar expression, we are no longer
290 // directly inside a filter.
291 in_filter = false;
292 }
293 }
294 Ok(())
295 })
296 }
297
298 walk_relation(expr, &[])
299}
300
/// An empty parameter type map.
///
/// These transformations are expected to run after parameters are bound, so
/// there is no need to provide any parameter type information.
///
/// The map is wrapped in a `LazyLock`, so it is constructed on first access.
/// Code in this file passes `&NO_PARAMS` wherever a `typ` call asks for a
/// parameter type map.
static NO_PARAMS: LazyLock<BTreeMap<usize, SqlScalarType>> = LazyLock::new(BTreeMap::new);
306
307fn column_type(
308 outers: &[SqlRelationType],
309 inner: &HirRelationExpr,
310 expr: &HirScalarExpr,
311) -> SqlColumnType {
312 let inner_type = inner.typ(outers, &NO_PARAMS);
313 expr.typ(outers, &inner_type, &NO_PARAMS)
314}
315
316impl HirScalarExpr {
317 /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
318 /// considers column references that target the root level.
319 /// (See `visit_columns_referring_to_root_level`.)
320 fn support(&self) -> Vec<usize> {
321 let mut result = Vec::new();
322 self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
323 result
324 }
325
326 /// Changes column references in `self` by the given remapping.
327 /// Panics if a referred column is not present in `idx_map`!
328 fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
329 self.visit_columns_referring_to_root_level_mut(&mut |c| {
330 *c = idx_map[c];
331 });
332 self
333 }
334}
335
336/// # Aims and scope
337///
338/// The aim here is to amortize the overhead of the MIR window function pattern
339/// (see `window_func_applied_to`) by fusing groups of window function calls such
340/// that each group can be performed by one instance of the window function MIR
341/// pattern.
342///
343/// For now, we fuse only value window function calls and window aggregations.
344/// (We probably won't need to fuse scalar window functions for a long time.)
345///
346/// For now, we can fuse value window function calls and window aggregations where the
347/// A. partition by
348/// B. order by
349/// C. window frame
350/// D. ignore nulls for value window functions and distinct for window aggregations
351/// are all the same. (See `extract_options`.)
352/// (Later, we could improve this to only need A. to be the same. This would require
353/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
354/// TODO: As a much simpler intermediate step, at least we should ignore options that
355/// don't matter. For example, we should be able to fuse a `lag` that has a default
356/// frame with a `first_value` that has some custom frame, because `lag` is not
357/// affected by the frame.)
358/// Note that we fuse value window function calls and window aggregations separately.
359///
360/// # Implementation
361///
362/// At a high level, what we are going to do is look for Maps with more than one window function
363/// calls, and for each Map
364/// - remove some groups of window function call expressions from the Map's `scalars`;
365/// - insert a fused version of each group;
366/// - insert some expressions that decompose the results of the fused calls;
367/// - update some column references in `scalars`: those that refer to window function results that
368/// participated in fusion, as well as those that refer to columns that moved around due to
369/// removing and inserting expressions.
370/// - insert a Project above the matched Map to permute columns back to their original places.
371///
372/// It would be tempting to find groups simply by taking a list of all window function calls
373/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
374/// but a complication is that the possible groups that we could theoretically fuse overlap.
375/// This is because when forming groups we need to also take into account column references
376/// that point inside the same Map. For example, imagine a Map with the following scalar
377/// expressions:
378/// C1, E1, C2, C3, where
379/// - E1 refers to C1
380/// - C3 refers to E1.
381/// In this situation, we could either
382/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
383/// to it);
384/// - or fuse C2 and C3.
385/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
386/// no appropriate place for the fused expression: it would have to be both before and after E1.
387///
388/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
389/// we go through `scalars`, try to put each expression into each of our groups, and the first of
390/// these succeed. When trying to put an expression into a group, we need to be mindful about column
391/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
392/// sanity is that the fused version of each group will be inserted at the place where the first
393/// element of the group originally was. This means that the only condition that we need to check on
394/// column references when adding an expression to a group is that all column references in a group
395/// should be to columns that are earlier than the first element of the group. (No need to check
396/// column references in the other direction, i.e., references in other expressions that refer to
397/// columns in the group.)
398pub fn fuse_window_functions(
399 root: &mut HirRelationExpr,
400 _context: &crate::plan::lowering::Context,
401) -> Result<(), RecursionLimitError> {
402 /// Those options of a window function call that are relevant for fusion.
403 #[derive(PartialEq, Eq)]
404 enum WindowFuncCallOptions {
405 Value(ValueWindowFuncCallOptions),
406 Agg(AggregateWindowFuncCallOptions),
407 }
408 #[derive(PartialEq, Eq)]
409 struct ValueWindowFuncCallOptions {
410 partition_by: Vec<HirScalarExpr>,
411 outer_order_by: Vec<HirScalarExpr>,
412 inner_order_by: Vec<ColumnOrder>,
413 window_frame: WindowFrame,
414 ignore_nulls: bool,
415 }
416 #[derive(PartialEq, Eq)]
417 struct AggregateWindowFuncCallOptions {
418 partition_by: Vec<HirScalarExpr>,
419 outer_order_by: Vec<HirScalarExpr>,
420 inner_order_by: Vec<ColumnOrder>,
421 window_frame: WindowFrame,
422 distinct: bool,
423 }
424
425 /// Helper function to extract the above options.
426 fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
427 match call {
428 HirScalarExpr::Windowing(
429 WindowExpr {
430 func:
431 WindowExprType::Value(ValueWindowExpr {
432 order_by: inner_order_by,
433 window_frame,
434 ignore_nulls,
435 func: _,
436 args: _,
437 }),
438 partition_by,
439 order_by: outer_order_by,
440 },
441 _name,
442 ) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
443 partition_by: partition_by.clone(),
444 outer_order_by: outer_order_by.clone(),
445 inner_order_by: inner_order_by.clone(),
446 window_frame: window_frame.clone(),
447 ignore_nulls: ignore_nulls.clone(),
448 }),
449 HirScalarExpr::Windowing(
450 WindowExpr {
451 func:
452 WindowExprType::Aggregate(AggregateWindowExpr {
453 aggregate_expr:
454 AggregateExpr {
455 distinct,
456 func: _,
457 expr: _,
458 },
459 order_by: inner_order_by,
460 window_frame,
461 }),
462 partition_by,
463 order_by: outer_order_by,
464 },
465 _name,
466 ) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
467 partition_by: partition_by.clone(),
468 outer_order_by: outer_order_by.clone(),
469 inner_order_by: inner_order_by.clone(),
470 window_frame: window_frame.clone(),
471 distinct: distinct.clone(),
472 }),
473 _ => panic!(
474 "extract_options should only be called on value window functions or window aggregations"
475 ),
476 }
477 }
478
479 struct FusionGroup {
480 /// The original column index of the first element of the group. (This is an index into the
481 /// Map's `scalars` plus the arity of the Map's input.)
482 first_col: usize,
483 /// The options of all the window function calls in the group. (Must be the same for all the
484 /// calls.)
485 options: WindowFuncCallOptions,
486 /// The calls in the group, with their original column indexes.
487 calls: Vec<(usize, HirScalarExpr)>,
488 }
489
490 impl FusionGroup {
491 /// Creates a window function call that is a fused version of all the calls in the group.
492 /// `new_col` is the column index where the fused call will be inserted at.
493 fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
494 let fused = match self.options {
495 WindowFuncCallOptions::Value(options) => {
496 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
497 .calls
498 .iter()
499 .map(|(_idx, call)| {
500 if let HirScalarExpr::Windowing(
501 WindowExpr {
502 func:
503 WindowExprType::Value(ValueWindowExpr {
504 func,
505 args,
506 order_by: _,
507 window_frame: _,
508 ignore_nulls: _,
509 }),
510 partition_by: _,
511 order_by: _,
512 },
513 _name,
514 ) = call
515 {
516 (func.clone(), (**args).clone())
517 } else {
518 panic!("unknown window function in FusionGroup")
519 }
520 })
521 .unzip();
522 let fused_args = HirScalarExpr::call_variadic(
523 RecordCreate {
524 // These field names are not important, because this record will only be an
525 // intermediate expression, which we'll manipulate further before it ends up
526 // anywhere where a column name would be visible.
527 field_names: iter::repeat(ColumnName::from(""))
528 .take(fused_args.len())
529 .collect(),
530 },
531 fused_args,
532 );
533 HirScalarExpr::windowing(WindowExpr {
534 func: WindowExprType::Value(ValueWindowExpr {
535 func: ValueWindowFunc::Fused(fused_funcs),
536 args: Box::new(fused_args),
537 order_by: options.inner_order_by,
538 window_frame: options.window_frame,
539 ignore_nulls: options.ignore_nulls,
540 }),
541 partition_by: options.partition_by,
542 order_by: options.outer_order_by,
543 })
544 }
545 WindowFuncCallOptions::Agg(options) => {
546 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
547 .calls
548 .iter()
549 .map(|(_idx, call)| {
550 if let HirScalarExpr::Windowing(
551 WindowExpr {
552 func:
553 WindowExprType::Aggregate(AggregateWindowExpr {
554 aggregate_expr:
555 AggregateExpr {
556 func,
557 expr,
558 distinct: _,
559 },
560 order_by: _,
561 window_frame: _,
562 }),
563 partition_by: _,
564 order_by: _,
565 },
566 _name,
567 ) = call
568 {
569 (func.clone(), (**expr).clone())
570 } else {
571 panic!("unknown window function in FusionGroup")
572 }
573 })
574 .unzip();
575 let fused_args = HirScalarExpr::call_variadic(
576 RecordCreate {
577 field_names: iter::repeat(ColumnName::from(""))
578 .take(fused_args.len())
579 .collect(),
580 },
581 fused_args,
582 );
583 HirScalarExpr::windowing(WindowExpr {
584 func: WindowExprType::Aggregate(AggregateWindowExpr {
585 aggregate_expr: AggregateExpr {
586 func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
587 expr: Box::new(fused_args),
588 distinct: options.distinct,
589 },
590 order_by: options.inner_order_by,
591 window_frame: options.window_frame,
592 }),
593 partition_by: options.partition_by,
594 order_by: options.outer_order_by,
595 })
596 }
597 };
598
599 let decompositions = (0..self.calls.len())
600 .map(|field| {
601 HirScalarExpr::column(new_col)
602 .call_unary(UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)))
603 })
604 .collect();
605
606 (fused, decompositions)
607 }
608 }
609
610 let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
611 // Look for calls only at the root of scalar expressions. This is enough
612 // because they are always there, see 72e84bb78.
613 match scalar_expr {
614 HirScalarExpr::Windowing(
615 WindowExpr {
616 func: WindowExprType::Value(ValueWindowExpr { func, .. }),
617 ..
618 },
619 _name,
620 ) => {
621 // Exclude those calls that are already fused. (We shouldn't currently
622 // encounter these, because we just do one pass, but it's better to be
623 // robust against future code changes.)
624 !matches!(func, ValueWindowFunc::Fused(..))
625 }
626 HirScalarExpr::Windowing(
627 WindowExpr {
628 func:
629 WindowExprType::Aggregate(AggregateWindowExpr {
630 aggregate_expr: AggregateExpr { func, .. },
631 ..
632 }),
633 ..
634 },
635 _name,
636 ) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
637 _ => false,
638 }
639 };
640
641 root.try_visit_mut_post(&mut |rel_expr| {
642 match rel_expr {
643 HirRelationExpr::Map { input, scalars } => {
644 // There will be various variable names involving `idx` or `col`:
645 // - `idx` will always be an index into `scalars` or something similar,
646 // - `col` will always be a column index,
647 // which is often `arity_before_map` + an index into `scalars`.
648 let arity_before_map = input.arity();
649 let orig_num_scalars = scalars.len();
650
651 // Collect all value window function calls and window aggregations with their column
652 // indexes.
653 let value_or_agg_window_func_calls = scalars
654 .iter()
655 .enumerate()
656 .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
657 .map(|(idx, call)| (idx + arity_before_map, call.clone()))
658 .collect_vec();
659 // Exit early if obviously no chance for fusion.
660 if value_or_agg_window_func_calls.len() <= 1 {
661 // Note that we are doing this only for performance. All plans should be exactly
662 // the same even if we comment out the following line.
663 return Ok(());
664 }
665
666 // Determine the fusion groups. (Each group will later be fused into one window
667 // function call.)
668 // Note that this has a quadratic run time with value_or_agg_window_func_calls in
669 // the worst case. However, this is fine even with 1000 window function calls.
670 let mut groups: Vec<FusionGroup> = Vec::new();
671 for (col, call) in value_or_agg_window_func_calls {
672 let options = extract_options(&call);
673 let support = call.support();
674 let to_fuse_with = groups
675 .iter_mut()
676 .filter(|group| {
677 group.options == options && support.iter().all(|c| *c < group.first_col)
678 })
679 .next();
680 if let Some(group) = to_fuse_with {
681 group.calls.push((col, call.clone()));
682 } else {
683 groups.push(FusionGroup {
684 first_col: col,
685 options,
686 calls: vec![(col, call.clone())],
687 });
688 }
689 }
690
691 // No fusion to do on groups of 1.
692 groups.retain(|g| g.calls.len() > 1);
693
694 let removals: BTreeSet<usize> = groups
695 .iter()
696 .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
697 .collect();
698
699 // Mutate `scalars`.
700 // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
701 // `groups` is already sorted by `first_col` due to the way it was constructed.)
702 // We also compute a remapping of old indexes to new indexes as we go.
703 let mut groups_it = groups.drain(..).peekable();
704 let mut group = groups_it.next();
705 let mut remap = BTreeMap::new();
706 remap.extend((0..arity_before_map).map(|col| (col, col)));
707 let mut new_col: usize = arity_before_map;
708 let mut new_scalars = Vec::new();
709 for (old_col, e) in scalars
710 .drain(..)
711 .enumerate()
712 .map(|(idx, e)| (idx + arity_before_map, e))
713 {
714 if group.as_ref().is_some_and(|g| g.first_col == old_col) {
715 // The current expression will be fused away, and a fused expression will
716 // appear in its place. Additionally, some new expressions will be inserted
717 // after the fused expression, to decompose the record that is the result of
718 // the fused call.
719 assert!(removals.contains(&old_col));
720 let group_unwrapped = group.expect("checked above");
721 let calls_cols = group_unwrapped
722 .calls
723 .iter()
724 .map(|(col, _call)| *col)
725 .collect_vec();
726 let (fused, decompositions) = group_unwrapped.fuse(new_col);
727 new_scalars.push(fused.remap(&remap));
728 new_scalars.extend(decompositions); // (no remapping needed)
729 new_col += 1;
730 for call_old_col in calls_cols {
731 let present = remap.insert(call_old_col, new_col);
732 assert!(present.is_none());
733 new_col += 1;
734 }
735 group = groups_it.next();
736 } else if removals.contains(&old_col) {
737 assert!(remap.contains_key(&old_col));
738 } else {
739 new_scalars.push(e.remap(&remap));
740 let present = remap.insert(old_col, new_col);
741 assert!(present.is_none());
742 new_col += 1;
743 }
744 }
745 *scalars = new_scalars;
746 assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
747
748 // Add a project to permute columns back to their original places.
749 *rel_expr = rel_expr.take().project(
750 (0..arity_before_map)
751 .chain((0..orig_num_scalars).map(|idx| {
752 *remap
753 .get(&(idx + arity_before_map))
754 .expect("all columns should be present by now")
755 }))
756 .collect(),
757 );
758
759 assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
760 }
761 _ => {}
762 }
763 Ok(())
764 })
765}