mz_sql/plan/transform_hir.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations of SQL IR, before decorrelation.
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::sync::LazyLock;
14use std::{iter, mem};
15
16use itertools::Itertools;
17use mz_expr::WindowFrame;
18use mz_expr::visit::Visit;
19use mz_expr::{ColumnOrder, UnaryFunc, VariadicFunc};
20use mz_ore::stack::RecursionLimitError;
21use mz_repr::{ColumnName, ColumnType, RelationType, ScalarType};
22
23use crate::plan::hir::{
24 AbstractExpr, AggregateFunc, AggregateWindowExpr, ColumnRef, HirRelationExpr, HirScalarExpr,
25 ValueWindowExpr, ValueWindowFunc, WindowExpr,
26};
27use crate::plan::{AggregateExpr, WindowExprType};
28
29/// Rewrites predicates that contain subqueries so that the subqueries
30/// appear in their own later predicate when possible.
31///
32/// For example, this function rewrites this expression
33///
34/// ```text
35/// Filter {
36/// predicates: [a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e]
37/// }
38/// ```
39///
40/// like so:
41///
42/// ```text
43/// Filter {
44/// predicates: [
45/// a = b AND c = d,
46/// EXISTS (<subquery>),
47/// (<subquery 2>) = e,
48/// ]
49/// }
50/// ```
51///
52/// The rewrite causes decorrelation to incorporate prior predicates into
53/// the outer relation upon which the subquery is evaluated. In the above
54/// rewritten example, the `EXISTS (<subquery>)` will only be evaluated for
55/// outer rows where `a = b AND c = d`. The second subquery, `(<subquery 2>)
56/// = e`, will be further restricted to outer rows that match `A = b AND c =
57/// d AND EXISTS(<subquery>)`. This can vastly reduce the cost of the
58/// subquery, especially when the original conjunction contains join keys.
59pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
60 fn walk_relation(expr: &mut HirRelationExpr) -> Result<(), RecursionLimitError> {
61 #[allow(deprecated)]
62 expr.visit_mut_fallible(0, &mut |expr, _| {
63 match expr {
64 HirRelationExpr::Map { scalars, .. } => {
65 for scalar in scalars {
66 walk_scalar(scalar)?;
67 }
68 }
69 HirRelationExpr::CallTable { exprs, .. } => {
70 for expr in exprs {
71 walk_scalar(expr)?;
72 }
73 }
74 HirRelationExpr::Filter { predicates, .. } => {
75 let mut subqueries = vec![];
76 for predicate in &mut *predicates {
77 walk_scalar(predicate)?;
78 extract_conjuncted_subqueries(predicate, &mut subqueries)?;
79 }
80 // TODO(benesch): we could be smarter about the order in which
81 // we emit subqueries. At the moment we just emit in the order
82 // we discovered them, but ideally we'd emit them in an order
83 // that accounted for their cost/selectivity. E.g., low-cost,
84 // high-selectivity subqueries should go first.
85 for subquery in subqueries {
86 predicates.push(subquery);
87 }
88 }
89 _ => (),
90 }
91 Ok(())
92 })
93 }
94
95 fn walk_scalar(expr: &mut HirScalarExpr) -> Result<(), RecursionLimitError> {
96 expr.try_visit_mut_post(&mut |expr| {
97 match expr {
98 HirScalarExpr::Exists(input) | HirScalarExpr::Select(input) => {
99 walk_relation(input)?
100 }
101 _ => (),
102 }
103 Ok(())
104 })
105 }
106
107 fn contains_subquery(expr: &HirScalarExpr) -> Result<bool, RecursionLimitError> {
108 let mut found = false;
109 expr.visit_pre(&mut |expr| match expr {
110 HirScalarExpr::Exists(_) | HirScalarExpr::Select(_) => found = true,
111 _ => (),
112 })?;
113 Ok(found)
114 }
115
116 /// Extracts subqueries from a conjunction into `out`.
117 ///
118 /// For example, given an expression like
119 ///
120 /// ```text
121 /// a = b AND EXISTS (<subquery 1>) AND c = d AND (<subquery 2>) = e
122 /// ```
123 ///
124 /// this function rewrites the expression to
125 ///
126 /// ```text
127 /// a = b AND true AND c = d AND true
128 /// ```
129 ///
130 /// and returns the expression fragments `EXISTS (<subquery 1>)` and
131 /// `(<subquery 2>) = e` in the `out` vector.
132 fn extract_conjuncted_subqueries(
133 expr: &mut HirScalarExpr,
134 out: &mut Vec<HirScalarExpr>,
135 ) -> Result<(), RecursionLimitError> {
136 match expr {
137 HirScalarExpr::CallVariadic {
138 func: VariadicFunc::And,
139 exprs,
140 } => {
141 exprs
142 .into_iter()
143 .try_for_each(|e| extract_conjuncted_subqueries(e, out))?;
144 }
145 expr if contains_subquery(expr)? => {
146 out.push(mem::replace(expr, HirScalarExpr::literal_true()))
147 }
148 _ => (),
149 }
150 Ok(())
151 }
152
153 walk_relation(expr)
154}
155
156/// Rewrites quantified comparisons into simpler EXISTS operators.
157///
158/// Note that this transformation is only valid when the expression is
159/// used in a context where the distinction between `FALSE` and `NULL`
160/// is immaterial, e.g., in a `WHERE` clause or a `CASE` condition, or
161/// when the inputs to the comparison are non-nullable. This function is careful
162/// to only apply the transformation when it is valid to do so.
163///
164/// ```ignore
165/// WHERE (SELECT any(<pred>) FROM <rel>)
166/// =>
167/// WHERE EXISTS(SELECT * FROM <rel> WHERE <pred>)
168///
169/// WHERE (SELECT all(<pred>) FROM <rel>)
170/// =>
171/// WHERE NOT EXISTS(SELECT * FROM <rel> WHERE (NOT <pred>) OR <pred> IS NULL)
172/// ```
173///
174/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
175/// M. Elhemali, et al.
176pub fn try_simplify_quantified_comparisons(
177 expr: &mut HirRelationExpr,
178) -> Result<(), RecursionLimitError> {
179 fn walk_relation(
180 expr: &mut HirRelationExpr,
181 outers: &[RelationType],
182 ) -> Result<(), RecursionLimitError> {
183 match expr {
184 HirRelationExpr::Map { scalars, input } => {
185 walk_relation(input, outers)?;
186 let mut outers = outers.to_vec();
187 outers.insert(0, input.typ(&outers, &NO_PARAMS));
188 for scalar in scalars {
189 walk_scalar(scalar, &outers, false)?;
190 let (inner, outers) = outers
191 .split_first_mut()
192 .expect("outers known to have at least one element");
193 let scalar_type = scalar.typ(outers, inner, &NO_PARAMS);
194 inner.column_types.push(scalar_type);
195 }
196 }
197 HirRelationExpr::Filter { predicates, input } => {
198 walk_relation(input, outers)?;
199 let mut outers = outers.to_vec();
200 outers.insert(0, input.typ(&outers, &NO_PARAMS));
201 for pred in predicates {
202 walk_scalar(pred, &outers, true)?;
203 }
204 }
205 HirRelationExpr::CallTable { exprs, .. } => {
206 let mut outers = outers.to_vec();
207 outers.insert(0, RelationType::empty());
208 for scalar in exprs {
209 walk_scalar(scalar, &outers, false)?;
210 }
211 }
212 HirRelationExpr::Join { left, right, .. } => {
213 walk_relation(left, outers)?;
214 let mut outers = outers.to_vec();
215 outers.insert(0, left.typ(&outers, &NO_PARAMS));
216 walk_relation(right, &outers)?;
217 }
218 expr => {
219 #[allow(deprecated)]
220 let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
221 walk_relation(expr, outers)
222 });
223 }
224 }
225 Ok(())
226 }
227
228 fn walk_scalar(
229 expr: &mut HirScalarExpr,
230 outers: &[RelationType],
231 mut in_filter: bool,
232 ) -> Result<(), RecursionLimitError> {
233 expr.try_visit_mut_pre(&mut |e| {
234 match e {
235 HirScalarExpr::Exists(input) => walk_relation(input, outers)?,
236 HirScalarExpr::Select(input) => {
237 walk_relation(input, outers)?;
238
239 // We're inside a `(SELECT ...)` subquery. Now let's see if
240 // it has the form `(SELECT <any|all>(...) FROM <input>)`.
241 // Ideally we could do this with one pattern, but Rust's pattern
242 // matching engine is not powerful enough, so we have to do this
243 // in stages; the early returns avoid brutal nesting.
244
245 let (func, expr, input) = match &mut **input {
246 HirRelationExpr::Reduce {
247 group_key,
248 aggregates,
249 input,
250 expected_group_size: _,
251 } if group_key.is_empty() && aggregates.len() == 1 => {
252 let agg = &mut aggregates[0];
253 (&agg.func, &mut agg.expr, input)
254 }
255 _ => return Ok(()),
256 };
257
258 if !in_filter && column_type(outers, input, expr).nullable {
259 // Unless we're directly inside a WHERE, this
260 // transformation is only valid if the expression involved
261 // is non-nullable.
262 return Ok(());
263 }
264
265 match func {
266 AggregateFunc::Any => {
267 // Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
268 // `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
269 *e = input.take().filter(vec![expr.take()]).exists();
270 }
271 AggregateFunc::All => {
272 // Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
273 // `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
274 //
275 // Note that negation of <expr> alone is insufficient.
276 // Consider that `WHERE <pred>` filters out rows if
277 // `<pred>` is false *or* null. To invert the test, we
278 // need `NOT <pred> OR <pred> IS NULL`.
279 let expr = expr.take();
280 let filter = expr.clone().not().or(expr.call_is_null());
281 *e = input.take().filter(vec![filter]).exists().not();
282 }
283 _ => (),
284 }
285 }
286 _ => {
287 // As soon as we see *any* scalar expression, we are no longer
288 // directly inside a filter.
289 in_filter = false;
290 }
291 }
292 Ok(())
293 })
294 }
295
296 walk_relation(expr, &[])
297}
298
299/// An empty parameter type map.
300///
301/// These transformations are expected to run after parameters are bound, so
302/// there is no need to provide any parameter type information.
303static NO_PARAMS: LazyLock<BTreeMap<usize, ScalarType>> = LazyLock::new(BTreeMap::new);
304
305fn column_type(
306 outers: &[RelationType],
307 inner: &HirRelationExpr,
308 expr: &HirScalarExpr,
309) -> ColumnType {
310 let inner_type = inner.typ(outers, &NO_PARAMS);
311 expr.typ(outers, &inner_type, &NO_PARAMS)
312}
313
314impl HirScalarExpr {
315 /// Similar to `MirScalarExpr::support`, but adapted to `HirScalarExpr` in a special way: it
316 /// considers column references that target the root level.
317 /// (See `visit_columns_referring_to_root_level`.)
318 fn support(&self) -> Vec<usize> {
319 let mut result = Vec::new();
320 self.visit_columns_referring_to_root_level(&mut |c| result.push(c));
321 result
322 }
323
324 /// Changes column references in `self` by the given remapping.
325 /// Panics if a referred column is not present in `idx_map`!
326 fn remap(mut self, idx_map: &BTreeMap<usize, usize>) -> HirScalarExpr {
327 self.visit_columns_referring_to_root_level_mut(&mut |c| {
328 *c = idx_map[c];
329 });
330 self
331 }
332}
333
334/// # Aims and scope
335///
336/// The aim here is to amortize the overhead of the MIR window function pattern
337/// (see `window_func_applied_to`) by fusing groups of window function calls such
338/// that each group can be performed by one instance of the window function MIR
339/// pattern.
340///
341/// For now, we fuse only value window function calls and window aggregations.
342/// (We probably won't need to fuse scalar window functions for a long time.)
343///
344/// For now, we can fuse value window function calls and window aggregations where the
345/// A. partition by
346/// B. order by
347/// C. window frame
348/// D. ignore nulls for value window functions and distinct for window aggregations
349/// are all the same. (See `extract_options`.)
350/// (Later, we could improve this to only need A. to be the same. This would require
351/// much more code changes, because then we'd have to blow up `ValueWindowExpr`.
352/// TODO: As a much simpler intermediate step, at least we should ignore options that
353/// don't matter. For example, we should be able to fuse a `lag` that has a default
354/// frame with a `first_value` that has some custom frame, because `lag` is not
355/// affected by the frame.)
356/// Note that we fuse value window function calls and window aggregations separately.
357///
358/// # Implementation
359///
360/// At a high level, what we are going to do is look for Maps with more than one window function
361/// calls, and for each Map
362/// - remove some groups of window function call expressions from the Map's `scalars`;
363/// - insert a fused version of each group;
364/// - insert some expressions that decompose the results of the fused calls;
365/// - update some column references in `scalars`: those that refer to window function results that
366/// participated in fusion, as well as those that refer to columns that moved around due to
367/// removing and inserting expressions.
368/// - insert a Project above the matched Map to permute columns back to their original places.
369///
370/// It would be tempting to find groups simply by taking a list of all window function calls
371/// and calling `group_by` with a key function that extracts the above A. B. C. D. properties,
372/// but a complication is that the possible groups that we could theoretically fuse overlap.
373/// This is because when forming groups we need to also take into account column references
374/// that point inside the same Map. For example, imagine a Map with the following scalar
375/// expressions:
376/// C1, E1, C2, C3, where
377/// - E1 refers to C1
378/// - C3 refers to E1.
379/// In this situation, we could either
380/// - fuse C1 and C2, and put the fused expression in the place of C1 (so that E1 can keep referring
381/// to it);
382/// - or fuse C2 and C3.
383/// However, we can't fuse all of C1, C2, C3 into one call, because then there would be
384/// no appropriate place for the fused expression: it would have to be both before and after E1.
385///
386/// So, how we actually form the groups is that, keeping track of a list of non-overlapping groups,
387/// we go through `scalars`, try to put each expression into each of our groups, and the first of
388/// these succeed. When trying to put an expression into a group, we need to be mindful about column
389/// references inside the same Map, as noted above. A constraint that we impose on ourselves for
390/// sanity is that the fused version of each group will be inserted at the place where the first
391/// element of the group originally was. This means that the only condition that we need to check on
392/// column references when adding an expression to a group is that all column references in a group
393/// should be to columns that are earlier than the first element of the group. (No need to check
394/// column references in the other direction, i.e., references in other expressions that refer to
395/// columns in the group.)
396pub fn fuse_window_functions(
397 root: &mut HirRelationExpr,
398 _context: &crate::plan::lowering::Context,
399) -> Result<(), RecursionLimitError> {
400 /// Those options of a window function call that are relevant for fusion.
401 #[derive(PartialEq, Eq)]
402 enum WindowFuncCallOptions {
403 Value(ValueWindowFuncCallOptions),
404 Agg(AggregateWindowFuncCallOptions),
405 }
406 #[derive(PartialEq, Eq)]
407 struct ValueWindowFuncCallOptions {
408 partition_by: Vec<HirScalarExpr>,
409 outer_order_by: Vec<HirScalarExpr>,
410 inner_order_by: Vec<ColumnOrder>,
411 window_frame: WindowFrame,
412 ignore_nulls: bool,
413 }
414 #[derive(PartialEq, Eq)]
415 struct AggregateWindowFuncCallOptions {
416 partition_by: Vec<HirScalarExpr>,
417 outer_order_by: Vec<HirScalarExpr>,
418 inner_order_by: Vec<ColumnOrder>,
419 window_frame: WindowFrame,
420 distinct: bool,
421 }
422
423 /// Helper function to extract the above options.
424 fn extract_options(call: &HirScalarExpr) -> WindowFuncCallOptions {
425 match call {
426 HirScalarExpr::Windowing(WindowExpr {
427 func:
428 WindowExprType::Value(ValueWindowExpr {
429 order_by: inner_order_by,
430 window_frame,
431 ignore_nulls,
432 func: _,
433 args: _,
434 }),
435 partition_by,
436 order_by: outer_order_by,
437 }) => WindowFuncCallOptions::Value(ValueWindowFuncCallOptions {
438 partition_by: partition_by.clone(),
439 outer_order_by: outer_order_by.clone(),
440 inner_order_by: inner_order_by.clone(),
441 window_frame: window_frame.clone(),
442 ignore_nulls: ignore_nulls.clone(),
443 }),
444 HirScalarExpr::Windowing(WindowExpr {
445 func:
446 WindowExprType::Aggregate(AggregateWindowExpr {
447 aggregate_expr:
448 AggregateExpr {
449 distinct,
450 func: _,
451 expr: _,
452 },
453 order_by: inner_order_by,
454 window_frame,
455 }),
456 partition_by,
457 order_by: outer_order_by,
458 }) => WindowFuncCallOptions::Agg(AggregateWindowFuncCallOptions {
459 partition_by: partition_by.clone(),
460 outer_order_by: outer_order_by.clone(),
461 inner_order_by: inner_order_by.clone(),
462 window_frame: window_frame.clone(),
463 distinct: distinct.clone(),
464 }),
465 _ => panic!(
466 "extract_options should only be called on value window functions or window aggregations"
467 ),
468 }
469 }
470
471 struct FusionGroup {
472 /// The original column index of the first element of the group. (This is an index into the
473 /// Map's `scalars` plus the arity of the Map's input.)
474 first_col: usize,
475 /// The options of all the window function calls in the group. (Must be the same for all the
476 /// calls.)
477 options: WindowFuncCallOptions,
478 /// The calls in the group, with their original column indexes.
479 calls: Vec<(usize, HirScalarExpr)>,
480 }
481
482 impl FusionGroup {
483 /// Creates a window function call that is a fused version of all the calls in the group.
484 /// `new_col` is the column index where the fused call will be inserted at.
485 fn fuse(self, new_col: usize) -> (HirScalarExpr, Vec<HirScalarExpr>) {
486 let fused = match self.options {
487 WindowFuncCallOptions::Value(options) => {
488 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
489 .calls
490 .iter()
491 .map(|(_idx, call)| {
492 if let HirScalarExpr::Windowing(WindowExpr {
493 func:
494 WindowExprType::Value(ValueWindowExpr {
495 func,
496 args,
497 order_by: _,
498 window_frame: _,
499 ignore_nulls: _,
500 }),
501 partition_by: _,
502 order_by: _,
503 }) = call
504 {
505 (func.clone(), (**args).clone())
506 } else {
507 panic!("unknown window function in FusionGroup")
508 }
509 })
510 .unzip();
511 let fused_args = HirScalarExpr::CallVariadic {
512 func: VariadicFunc::RecordCreate {
513 // These field names are not important, because this record will only be an
514 // intermediate expression, which we'll manipulate further before it ends up
515 // anywhere where a column name would be visible.
516 field_names: iter::repeat(ColumnName::from(""))
517 .take(fused_args.len())
518 .collect(),
519 },
520 exprs: fused_args,
521 };
522 HirScalarExpr::Windowing(WindowExpr {
523 func: WindowExprType::Value(ValueWindowExpr {
524 func: ValueWindowFunc::Fused(fused_funcs),
525 args: Box::new(fused_args),
526 order_by: options.inner_order_by,
527 window_frame: options.window_frame,
528 ignore_nulls: options.ignore_nulls,
529 }),
530 partition_by: options.partition_by,
531 order_by: options.outer_order_by,
532 })
533 }
534 WindowFuncCallOptions::Agg(options) => {
535 let (fused_funcs, fused_args): (Vec<_>, Vec<_>) = self
536 .calls
537 .iter()
538 .map(|(_idx, call)| {
539 if let HirScalarExpr::Windowing(WindowExpr {
540 func:
541 WindowExprType::Aggregate(AggregateWindowExpr {
542 aggregate_expr:
543 AggregateExpr {
544 func,
545 expr,
546 distinct: _,
547 },
548 order_by: _,
549 window_frame: _,
550 }),
551 partition_by: _,
552 order_by: _,
553 }) = call
554 {
555 (func.clone(), (**expr).clone())
556 } else {
557 panic!("unknown window function in FusionGroup")
558 }
559 })
560 .unzip();
561 let fused_args = HirScalarExpr::CallVariadic {
562 func: VariadicFunc::RecordCreate {
563 field_names: iter::repeat(ColumnName::from(""))
564 .take(fused_args.len())
565 .collect(),
566 },
567 exprs: fused_args,
568 };
569 HirScalarExpr::Windowing(WindowExpr {
570 func: WindowExprType::Aggregate(AggregateWindowExpr {
571 aggregate_expr: AggregateExpr {
572 func: AggregateFunc::FusedWindowAgg { funcs: fused_funcs },
573 expr: Box::new(fused_args),
574 distinct: options.distinct,
575 },
576 order_by: options.inner_order_by,
577 window_frame: options.window_frame,
578 }),
579 partition_by: options.partition_by,
580 order_by: options.outer_order_by,
581 })
582 }
583 };
584
585 let decompositions = (0..self.calls.len())
586 .map(|field| HirScalarExpr::CallUnary {
587 func: UnaryFunc::RecordGet(mz_expr::func::RecordGet(field)),
588 expr: Box::new(HirScalarExpr::Column(ColumnRef {
589 level: 0,
590 column: new_col,
591 })),
592 })
593 .collect();
594
595 (fused, decompositions)
596 }
597 }
598
599 let is_value_or_agg_window_func_call = |scalar_expr: &HirScalarExpr| -> bool {
600 // Look for calls only at the root of scalar expressions. This is enough
601 // because they are always there, see 72e84bb78.
602 match scalar_expr {
603 HirScalarExpr::Windowing(WindowExpr {
604 func: WindowExprType::Value(ValueWindowExpr { func, .. }),
605 ..
606 }) => {
607 // Exclude those calls that are already fused. (We shouldn't currently
608 // encounter these, because we just do one pass, but it's better to be
609 // robust against future code changes.)
610 !matches!(func, ValueWindowFunc::Fused(..))
611 }
612 HirScalarExpr::Windowing(WindowExpr {
613 func:
614 WindowExprType::Aggregate(AggregateWindowExpr {
615 aggregate_expr: AggregateExpr { func, .. },
616 ..
617 }),
618 ..
619 }) => !matches!(func, AggregateFunc::FusedWindowAgg { .. }),
620 _ => false,
621 }
622 };
623
624 root.try_visit_mut_post(&mut |rel_expr| {
625 match rel_expr {
626 HirRelationExpr::Map { input, scalars } => {
627 // There will be various variable names involving `idx` or `col`:
628 // - `idx` will always be an index into `scalars` or something similar,
629 // - `col` will always be a column index,
630 // which is often `arity_before_map` + an index into `scalars`.
631 let arity_before_map = input.arity();
632 let orig_num_scalars = scalars.len();
633
634 // Collect all value window function calls and window aggregations with their column
635 // indexes.
636 let value_or_agg_window_func_calls = scalars
637 .iter()
638 .enumerate()
639 .filter(|(_idx, scalar_expr)| is_value_or_agg_window_func_call(scalar_expr))
640 .map(|(idx, call)| (idx + arity_before_map, call.clone()))
641 .collect_vec();
642 // Exit early if obviously no chance for fusion.
643 if value_or_agg_window_func_calls.len() <= 1 {
644 // Note that we are doing this only for performance. All plans should be exactly
645 // the same even if we comment out the following line.
646 return Ok(());
647 }
648
649 // Determine the fusion groups. (Each group will later be fused into one window
650 // function call.)
651 // Note that this has a quadratic run time with value_or_agg_window_func_calls in
652 // the worst case. However, this is fine even with 1000 window function calls.
653 let mut groups: Vec<FusionGroup> = Vec::new();
654 for (col, call) in value_or_agg_window_func_calls {
655 let options = extract_options(&call);
656 let support = call.support();
657 let to_fuse_with = groups
658 .iter_mut()
659 .filter(|group| {
660 group.options == options && support.iter().all(|c| *c < group.first_col)
661 })
662 .next();
663 if let Some(group) = to_fuse_with {
664 group.calls.push((col, call.clone()));
665 } else {
666 groups.push(FusionGroup {
667 first_col: col,
668 options,
669 calls: vec![(col, call.clone())],
670 });
671 }
672 }
673
674 // No fusion to do on groups of 1.
675 groups.retain(|g| g.calls.len() > 1);
676
677 let removals: BTreeSet<usize> = groups
678 .iter()
679 .flat_map(|g| g.calls.iter().map(|(col, _)| *col))
680 .collect();
681
682 // Mutate `scalars`.
683 // We do this by simultaneously iterating through `scalars` and `groups`. (Note that
684 // `groups` is already sorted by `first_col` due to the way it was constructed.)
685 // We also compute a remapping of old indexes to new indexes as we go.
686 let mut groups_it = groups.drain(..).peekable();
687 let mut group = groups_it.next();
688 let mut remap = BTreeMap::new();
689 remap.extend((0..arity_before_map).map(|col| (col, col)));
690 let mut new_col: usize = arity_before_map;
691 let mut new_scalars = Vec::new();
692 for (old_col, e) in scalars
693 .drain(..)
694 .enumerate()
695 .map(|(idx, e)| (idx + arity_before_map, e))
696 {
697 if group.as_ref().is_some_and(|g| g.first_col == old_col) {
698 // The current expression will be fused away, and a fused expression will
699 // appear in its place. Additionally, some new expressions will be inserted
700 // after the fused expression, to decompose the record that is the result of
701 // the fused call.
702 assert!(removals.contains(&old_col));
703 let group_unwrapped = group.expect("checked above");
704 let calls_cols = group_unwrapped
705 .calls
706 .iter()
707 .map(|(col, _call)| *col)
708 .collect_vec();
709 let (fused, decompositions) = group_unwrapped.fuse(new_col);
710 new_scalars.push(fused.remap(&remap));
711 new_scalars.extend(decompositions); // (no remapping needed)
712 new_col += 1;
713 for call_old_col in calls_cols {
714 let present = remap.insert(call_old_col, new_col);
715 assert!(present.is_none());
716 new_col += 1;
717 }
718 group = groups_it.next();
719 } else if removals.contains(&old_col) {
720 assert!(remap.contains_key(&old_col));
721 } else {
722 new_scalars.push(e.remap(&remap));
723 let present = remap.insert(old_col, new_col);
724 assert!(present.is_none());
725 new_col += 1;
726 }
727 }
728 *scalars = new_scalars;
729 assert_eq!(remap.len(), arity_before_map + orig_num_scalars);
730
731 // Add a project to permute columns back to their original places.
732 *rel_expr = rel_expr.take().project(
733 (0..arity_before_map)
734 .chain((0..orig_num_scalars).map(|idx| {
735 *remap
736 .get(&(idx + arity_before_map))
737 .expect("all columns should be present by now")
738 }))
739 .collect(),
740 );
741
742 assert_eq!(rel_expr.arity(), arity_before_map + orig_num_scalars);
743 }
744 _ => {}
745 }
746 Ok(())
747 })
748}