mz_sql/plan/
lowering.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::visit::Visit;
44use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
45use mz_ore::collections::CollectionExt;
46use mz_ore::stack::maybe_grow;
47use mz_repr::*;
48
49use crate::optimizer_metrics::OptimizerMetrics;
50use crate::plan::hir::{
51    AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
52};
53use crate::plan::{PlanError, transform_hir};
54use crate::session::vars::SystemVars;
55
56mod variadic_left;
57
58/// Maps a leveled column reference to a specific column.
59///
60/// Leveled column references are nested, so that larger levels are
61/// found early in a record and level zero is found at the end.
62///
63/// The column map only stores references for levels greater than zero,
64/// and column references at level zero simply start at the first column
65/// after all prior references.
66#[derive(Debug, Clone)]
67struct ColumnMap {
68    inner: BTreeMap<ColumnRef, usize>,
69}
70
71impl ColumnMap {
72    fn empty() -> ColumnMap {
73        Self::new(BTreeMap::new())
74    }
75
76    fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
77        ColumnMap { inner }
78    }
79
80    fn get(&self, col_ref: &ColumnRef) -> usize {
81        if col_ref.level == 0 {
82            self.inner.len() + col_ref.column
83        } else {
84            self.inner[col_ref]
85        }
86    }
87
88    fn len(&self) -> usize {
89        self.inner.len()
90    }
91
92    /// Updates references in the `ColumnMap` for use in a nested scope. The
93    /// provided `arity` must specify the arity of the current scope.
94    fn enter_scope(&self, arity: usize) -> ColumnMap {
95        // From the perspective of the nested scope, all existing column
96        // references will be one level greater.
97        let existing = self
98            .inner
99            .clone()
100            .into_iter()
101            .update(|(col, _i)| col.level += 1);
102
103        // All columns in the current scope become explicit entries in the
104        // immediate parent scope.
105        let new = (0..arity).map(|i| {
106            (
107                ColumnRef {
108                    level: 1,
109                    column: i,
110                },
111                self.len() + i,
112            )
113        });
114
115        ColumnMap::new(existing.chain(new).collect())
116    }
117}
118
119/// Map with the CTEs currently in scope.
120type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
121
122/// Information about needed when finding a reference to a CTE in scope.
123#[derive(Clone)]
124struct CteDesc {
125    /// The new ID assigned to the lowered version of the CTE, which may not match
126    /// the ID of the input CTE.
127    new_id: mz_expr::LocalId,
128    /// The relation type of the CTE including the columns from the outer
129    /// context at the beginning.
130    relation_type: SqlRelationType,
131    /// The outer relation the CTE was applied to.
132    outer_relation: MirRelationExpr,
133}
134
135#[derive(Debug)]
136pub struct Config {
137    /// Enable outer join lowering implemented in database-issues#6747.
138    pub enable_new_outer_join_lowering: bool,
139    /// Enable outer join lowering implemented in database-issues#7561.
140    pub enable_variadic_left_join_lowering: bool,
141    pub enable_guard_subquery_tablefunc: bool,
142}
143
144impl From<&SystemVars> for Config {
145    fn from(vars: &SystemVars) -> Self {
146        Self {
147            enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
148            enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
149            enable_guard_subquery_tablefunc: vars.enable_guard_subquery_tablefunc(),
150        }
151    }
152}
153
154/// Context passed to the lowering. This is wired to most parts of the lowering.
155pub(crate) struct Context<'a> {
156    /// Feature flags affecting the behavior of lowering.
157    pub config: &'a Config,
158    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's None, we
159    /// simply don't write metrics.
160    pub metrics: Option<&'a OptimizerMetrics>,
161}
162
163impl HirRelationExpr {
164    /// Rewrite `self` into a `MirRelationExpr`.
165    /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
166    #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
167    pub fn lower<C: Into<Config>>(
168        self,
169        config: C,
170        metrics: Option<&OptimizerMetrics>,
171    ) -> Result<MirRelationExpr, PlanError> {
172        let context = Context {
173            config: &config.into(),
174            metrics,
175        };
176        let result = match self {
177            // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
178            // to ensure that the downstream optimizer can easily bypass most
179            // irrelevant optimizations (e.g. reduce folding) for this expression
180            // without having to re-learn the fact that it is just a constant,
181            // as it would if the constant were wrapped in a Let-Get pair.
182            HirRelationExpr::Constant { rows, typ } => {
183                let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
184                MirRelationExpr::Constant {
185                    rows: Ok(rows),
186                    typ,
187                }
188            }
189            mut other => {
190                let mut id_gen = mz_ore::id_gen::IdGen::default();
191                transform_hir::split_subquery_predicates(&mut other)?;
192                transform_hir::try_simplify_quantified_comparisons(&mut other)?;
193                transform_hir::fuse_window_functions(&mut other, &context)?;
194                MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![])).let_in(
195                    &mut id_gen,
196                    |id_gen, get_outer| {
197                        other.applied_to(
198                            id_gen,
199                            get_outer,
200                            &ColumnMap::empty(),
201                            &mut CteMap::new(),
202                            &context,
203                        )
204                    },
205                )?
206            }
207        };
208
209        mz_repr::explain::trace_plan(&result);
210
211        Ok(result)
212    }
213
214    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
215    ///
216    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
217    /// When `self` references columns of `get_outer`, much more work needs to occur.
218    ///
219    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
220    /// perhaps not all of them. It should be used as the basis of resolving column references,
221    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
222    /// will start, rather than any function of `col_map`.
223    ///
224    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
225    /// assignment of values to outer rows.
226    fn applied_to(
227        self,
228        id_gen: &mut mz_ore::id_gen::IdGen,
229        get_outer: MirRelationExpr,
230        col_map: &ColumnMap,
231        cte_map: &mut CteMap,
232        context: &Context,
233    ) -> Result<MirRelationExpr, PlanError> {
234        maybe_grow(|| {
235            use MirRelationExpr as SR;
236
237            use HirRelationExpr::*;
238
239            if let MirRelationExpr::Get { .. } = &get_outer {
240            } else {
241                panic!(
242                    "get_outer: expected a MirRelationExpr::Get, found\n{}",
243                    get_outer.pretty(),
244                );
245            }
246            assert_eq!(col_map.len(), get_outer.arity());
247            Ok(match self {
248                Constant { rows, typ } => {
249                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
250                    get_outer.product(SR::Constant {
251                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
252                        typ,
253                    })
254                }
255                Get { id, typ } => match id {
256                    mz_expr::Id::Local(local_id) => {
257                        let cte_desc = cte_map.get(&local_id).unwrap();
258                        let get_cte = SR::Get {
259                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
260                            typ: cte_desc.relation_type.clone(),
261                            access_strategy: AccessStrategy::UnknownOrLocal,
262                        };
263                        if get_outer == cte_desc.outer_relation {
264                            // If the CTE was applied to the same exact relation, we can safely
265                            // return a `Get` relation.
266                            get_cte
267                        } else {
268                            // Otherwise, the new outer relation may contain more columns from some
269                            // intermediate scope placed between the definition of the CTE and this
270                            // reference of the CTE and/or more operations applied on top of the
271                            // outer relation.
272                            //
273                            // An example of the latter is the following query:
274                            //
275                            // SELECT *
276                            // FROM x,
277                            //      LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
278                            //              SELECT (SELECT m FROM a) FROM y) b;
279                            //
280                            // When the CTE is lowered, the outer relation is `Get x`. But then,
281                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
282                            // which has the same cardinality as `Get x`.
283                            //
284                            // In any case, `get_outer` is guaranteed to contain the columns of the
285                            // outer relation the CTE was applied to at its prefix. Since, we must
286                            // return a relation containing `get_outer`'s column at the beginning,
287                            // we must build a join between `get_outer` and `get_cte` on their common
288                            // columns.
289                            let oa = get_outer.arity();
290                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
291                            let equivalences = (0..cte_outer_columns)
292                                .map(|pos| {
293                                    vec![
294                                        MirScalarExpr::column(pos),
295                                        MirScalarExpr::column(pos + oa),
296                                    ]
297                                })
298                                .collect();
299
300                            // Project out the second copy of the common between `get_outer` and
301                            // `cte_desc.outer_relation`.
302                            let projection = (0..oa)
303                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
304                                .collect_vec();
305                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
306                                .project(projection)
307                        }
308                    }
309                    _ => {
310                        // Get statements are only to external sources, and are not correlated with `get_outer`.
311                        get_outer.product(SR::Get {
312                            id,
313                            typ,
314                            access_strategy: AccessStrategy::UnknownOrLocal,
315                        })
316                    }
317                },
318                Let {
319                    name: _,
320                    id,
321                    value,
322                    body,
323                } => {
324                    let value =
325                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
326                    value.let_in(id_gen, |id_gen, get_value| {
327                        let (new_id, typ) = if let MirRelationExpr::Get {
328                            id: mz_expr::Id::Local(id),
329                            typ,
330                            ..
331                        } = get_value
332                        {
333                            (id, typ)
334                        } else {
335                            panic!(
336                                "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
337                                get_value.pretty(),
338                            );
339                        };
340                        // Add the information about the CTE to the map and remove it when
341                        // it goes out of scope.
342                        let old_value = cte_map.insert(
343                            id.clone(),
344                            CteDesc {
345                                new_id,
346                                relation_type: typ,
347                                outer_relation: get_outer.clone(),
348                            },
349                        );
350                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
351                        if let Some(old_value) = old_value {
352                            cte_map.insert(id, old_value);
353                        } else {
354                            cte_map.remove(&id);
355                        }
356                        body
357                    })?
358                }
359                LetRec {
360                    limit,
361                    bindings,
362                    body,
363                } => {
364                    let num_bindings = bindings.len();
365
366                    // We use the outer type with the HIR types to form MIR CTE types.
367                    let outer_column_types = get_outer.typ().column_types;
368
369                    // Rename and introduce all bindings.
370                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
371                    let mut mir_ids = Vec::with_capacity(num_bindings);
372                    for (_name, id, _value, typ) in bindings.iter() {
373                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
374                        mir_ids.push(mir_id);
375                        let shadowed = cte_map.insert(
376                            id.clone(),
377                            CteDesc {
378                                new_id: mir_id,
379                                relation_type: SqlRelationType::new(
380                                    outer_column_types
381                                        .iter()
382                                        .cloned()
383                                        .chain(typ.column_types.iter().cloned())
384                                        .collect::<Vec<_>>(),
385                                ),
386                                outer_relation: get_outer.clone(),
387                            },
388                        );
389                        shadowed_bindings.push((*id, shadowed));
390                    }
391
392                    let mut mir_values = Vec::with_capacity(num_bindings);
393                    for (_name, _id, value, _typ) in bindings.into_iter() {
394                        mir_values.push(value.applied_to(
395                            id_gen,
396                            get_outer.clone(),
397                            col_map,
398                            cte_map,
399                            context,
400                        )?);
401                    }
402
403                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
404
405                    // Remove our bindings and reinstate any shadowed bindings.
406                    for (id, shadowed) in shadowed_bindings {
407                        if let Some(shadowed) = shadowed {
408                            cte_map.insert(id, shadowed);
409                        } else {
410                            cte_map.remove(&id);
411                        }
412                    }
413
414                    MirRelationExpr::LetRec {
415                        ids: mir_ids,
416                        values: mir_values,
417                        // Copy the limit to each binding.
418                        limits: repeat(limit).take(num_bindings).collect(),
419                        body: Box::new(mir_body),
420                    }
421                }
422                Project { input, outputs } => {
423                    // Projections should be applied to the decorrelated `inner`, and to its columns,
424                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
425                    let input =
426                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
427                    let outputs = (0..get_outer.arity())
428                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
429                        .collect::<Vec<_>>();
430                    input.project(outputs)
431                }
432                Map { input, mut scalars } => {
433                    // Scalar expressions may contain correlated subqueries. We must be cautious!
434
435                    // We lower scalars in chunks, and must keep track of the
436                    // arity of the HIR fragments lowered so far.
437                    let mut lowered_arity = input.arity();
438
439                    let mut input =
440                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
441
442                    // Lower subqueries in maximally sized batches, such as no subquery in the current
443                    // batch depends on columns from the same batch.
444                    // Note that subqueries in this projection may reference columns added by this
445                    // Map operator, so we need to ensure these columns exist before lowering the
446                    // subquery.
447                    while !scalars.is_empty() {
448                        let end_idx = scalars
449                            .iter_mut()
450                            .position(|s| {
451                                let mut requires_nonexistent_column = false;
452                                #[allow(deprecated)]
453                                s.visit_columns(0, &mut |depth, col| {
454                                    if col.level == depth {
455                                        requires_nonexistent_column |= col.column >= lowered_arity
456                                    }
457                                });
458                                requires_nonexistent_column
459                            })
460                            .unwrap_or(scalars.len());
461                        assert!(
462                            end_idx > 0,
463                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
464                            lowered_arity,
465                            scalars
466                        );
467
468                        lowered_arity = lowered_arity + end_idx;
469                        let scalars = scalars.drain(0..end_idx).collect_vec();
470
471                        let old_arity = input.arity();
472                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
473                            &scalars, id_gen, col_map, cte_map, input, context,
474                        )?;
475                        input = with_subqueries;
476
477                        // We will proceed sequentially through the scalar expressions, for each transforming
478                        // the decorrelated `input` into a relation with potentially more columns capable of
479                        // addressing the needs of the scalar expression.
480                        // Having done so, we add the scalar value of interest and trim off any other newly
481                        // added columns.
482                        //
483                        // The sequential traversal is present as expressions are allowed to depend on the
484                        // values of prior expressions.
485                        let mut scalar_columns = Vec::new();
486                        for scalar in scalars {
487                            let scalar = scalar.applied_to(
488                                id_gen,
489                                col_map,
490                                cte_map,
491                                &mut input,
492                                &Some(&subquery_map),
493                                context,
494                            )?;
495                            input = input.map_one(scalar);
496                            scalar_columns.push(input.arity() - 1);
497                        }
498
499                        // Discard any new columns added by the lowering of the scalar expressions
500                        input = input.project((0..old_arity).chain(scalar_columns).collect());
501                    }
502
503                    input
504                }
505                CallTable { func, exprs } => {
506                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
507                    // allowed to refer to the results of previous expressions, and we have a simpler
508                    // implementation that appends all relevant columns first, then applies the flatmap
509                    // operator to the result, then strips off any columns introduce by subqueries.
510
511                    let mut input = get_outer;
512                    let old_arity = input.arity();
513
514                    let exprs = exprs
515                        .into_iter()
516                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
517                        .collect::<Result<Vec<_>, _>>()?;
518
519                    let new_arity = input.arity();
520                    let output_arity = func.output_arity();
521                    input = input.flat_map(func, exprs);
522                    if old_arity != new_arity {
523                        // this means we added some columns to handle subqueries, and now we need to get rid of them
524                        input = input.project(
525                            (0..old_arity)
526                                .chain(new_arity..new_arity + output_arity)
527                                .collect(),
528                        );
529                    }
530                    input
531                }
532                Filter { input, predicates } => {
533                    // Filter expressions may contain correlated subqueries.
534                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
535                    // then filter the results, then strip off any columns that were added for this purpose.
536                    let mut input =
537                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
538                    for predicate in predicates {
539                        let old_arity = input.arity();
540                        let predicate = predicate
541                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
542                        let new_arity = input.arity();
543                        input = input.filter(vec![predicate]);
544                        if old_arity != new_arity {
545                            // this means we added some columns to handle subqueries, and now we need to get rid of them
546                            input = input.project((0..old_arity).collect());
547                        }
548                    }
549                    input
550                }
551                Join {
552                    left,
553                    right,
554                    on,
555                    kind,
556                } if right.is_correlated() => {
557                    // A correlated join is a join in which the right expression has
558                    // access to the columns in the left expression. It turns out
559                    // this is *exactly* our branch operator, plus some additional
560                    // null handling in the case of left joins. (Right and full
561                    // lateral joins are not permitted.)
562                    //
563                    // As with normal joins, the `on` predicate may be correlated,
564                    // and we treat it as a filter that follows the branch.
565
566                    assert!(kind.can_be_correlated());
567
568                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
569                    left.let_in(id_gen, |id_gen, get_left| {
570                        let apply_requires_distinct_outer = false;
571                        let mut join = branch(
572                            id_gen,
573                            get_left.clone(),
574                            col_map,
575                            cte_map,
576                            *right,
577                            apply_requires_distinct_outer,
578                            context,
579                            |id_gen, right, get_left, col_map, cte_map, context| {
580                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
581                            },
582                        )?;
583
584                        // Plan the `on` predicate.
585                        let old_arity = join.arity();
586                        let on =
587                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
588                        join = join.filter(vec![on]);
589                        let new_arity = join.arity();
590                        if old_arity != new_arity {
591                            // This means we added some columns to handle
592                            // subqueries, and now we need to get rid of them.
593                            join = join.project((0..old_arity).collect());
594                        }
595
596                        // If a left join, reintroduce any rows from the left that
597                        // are missing, with nulls filled in for the right columns.
598                        if let JoinKind::LeftOuter { .. } = kind {
599                            let default = join
600                                .typ()
601                                .column_types
602                                .into_iter()
603                                .skip(get_left.arity())
604                                .map(|typ| (Datum::Null, typ.scalar_type))
605                                .collect();
606                            get_left.lookup(id_gen, join, default)
607                        } else {
608                            Ok::<_, PlanError>(join)
609                        }
610                    })?
611                }
612                Join {
613                    left,
614                    right,
615                    on,
616                    kind,
617                } => {
618                    if context.config.enable_variadic_left_join_lowering {
619                        // Attempt to extract a stack of left joins.
620                        if let JoinKind::LeftOuter = kind {
621                            let mut rights = vec![(&*right, &on)];
622                            let mut left_test = &left;
623                            while let Join {
624                                left,
625                                right,
626                                on,
627                                kind: JoinKind::LeftOuter,
628                            } = &**left_test
629                            {
630                                rights.push((&**right, on));
631                                left_test = left;
632                            }
633                            if rights.len() > 1 {
634                                // Defensively clone `cte_map` as it may be mutated.
635                                let cte_map_clone = cte_map.clone();
636                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
637                                    left_test,
638                                    rights,
639                                    id_gen,
640                                    get_outer.clone(),
641                                    col_map,
642                                    cte_map,
643                                    context,
644                                ) {
645                                    return Ok(magic);
646                                } else {
647                                    cte_map.clone_from(&cte_map_clone);
648                                }
649                            }
650                        }
651                    }
652
653                    // Both join expressions should be decorrelated, and then joined by their
654                    // leading columns to form only those pairs corresponding to the same row
655                    // of `get_outer`.
656                    //
657                    // The `on` predicate may contain correlated subqueries, and we treat it
658                    // as though it was a filter, with the caveat that we also translate outer
659                    // joins in this step. The post-filtration results need to be considered
660                    // against the records present in the left and right (decorrelated) inputs,
661                    // depending on the type of join.
662                    let oa = get_outer.arity();
663                    let left =
664                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
665                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
666                    let la = lt.len();
667                    left.let_in(id_gen, |id_gen, get_left| {
668                        let right_col_map = col_map.enter_scope(0);
669                        let right = right.applied_to(
670                            id_gen,
671                            get_outer.clone(),
672                            &right_col_map,
673                            cte_map,
674                            context,
675                        )?;
676                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
677                        let ra = rt.len();
678                        right.let_in(id_gen, |id_gen, get_right| {
679                            let mut product = SR::join(
680                                vec![get_left.clone(), get_right.clone()],
681                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
682                            )
683                            // Project away the repeated copy of get_outer's columns.
684                            .project(
685                                (0..(oa + la))
686                                    .chain((oa + la + oa)..(oa + la + oa + ra))
687                                    .collect(),
688                            );
689
690                            // Decorrelate and lower the `on` clause.
691                            let on = on.applied_to(
692                                id_gen,
693                                col_map,
694                                cte_map,
695                                &mut product,
696                                &None,
697                                context,
698                            )?;
699                            // Collect the types of all subqueries appearing in
700                            // the `on` clause. The subquery results were
701                            // appended to `product` in the `on.applied_to(...)`
702                            // call above.
703                            let on_subquery_types = product
704                                .typ()
705                                .column_types
706                                .drain(oa + la + ra..)
707                                .collect_vec();
708                            // Remember if `on` had any subqueries.
709                            let on_has_subqueries = !on_subquery_types.is_empty();
710
711                            // Attempt an efficient equijoin implementation, in which outer joins are
712                            // more efficiently rendered than in general. This can return `None` if
713                            // such a plan is not possible, for example if `on` does not describe an
714                            // equijoin between columns of `left` and `right`.
715                            if kind != JoinKind::Inner {
716                                if let Some(joined) = attempt_outer_equijoin(
717                                    get_left.clone(),
718                                    get_right.clone(),
719                                    on.clone(),
720                                    on_subquery_types,
721                                    kind.clone(),
722                                    oa,
723                                    id_gen,
724                                    context,
725                                )? {
726                                    if let Some(metrics) = context.metrics {
727                                        metrics.inc_outer_join_lowering("equi");
728                                    }
729                                    return Ok(joined);
730                                }
731                            }
732
733                            // Otherwise, perform a more general join.
734                            if let Some(metrics) = context.metrics {
735                                metrics.inc_outer_join_lowering("general");
736                            }
737                            let mut join = product.filter(vec![on]);
738                            if on_has_subqueries {
739                                // This means that `on.applied_to(...)` appended
740                                // some columns to handle subqueries, and now we
741                                // need to get rid of them.
742                                join = join.project((0..oa + la + ra).collect());
743                            }
744                            join.let_in(id_gen, |id_gen, get_join| {
745                                let mut result = get_join.clone();
746                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
747                                    kind
748                                {
749                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
750                                        id_gen,
751                                        get_join.clone(),
752                                        rt.into_iter()
753                                            .map(|typ| (Datum::Null, typ.scalar_type))
754                                            .collect(),
755                                    )?;
756                                    result = result.union(left_outer);
757                                }
758                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
759                                    let right_outer = get_right
760                                        .clone()
761                                        .anti_lookup::<PlanError>(
762                                            id_gen,
763                                            get_join
764                                                // need to swap left and right to make the anti_lookup work
765                                                .project(
766                                                    (0..oa)
767                                                        .chain((oa + la)..(oa + la + ra))
768                                                        .chain((oa)..(oa + la))
769                                                        .collect(),
770                                                ),
771                                            lt.into_iter()
772                                                .map(|typ| (Datum::Null, typ.scalar_type))
773                                                .collect(),
774                                        )?
775                                        // swap left and right back again
776                                        .project(
777                                            (0..oa)
778                                                .chain((oa + ra)..(oa + ra + la))
779                                                .chain((oa)..(oa + ra))
780                                                .collect(),
781                                        );
782                                    result = result.union(right_outer);
783                                }
784                                Ok::<MirRelationExpr, PlanError>(result)
785                            })
786                        })
787                    })?
788                }
789                Union { base, inputs } => {
790                    // Union is uncomplicated.
791                    SR::Union {
792                        base: Box::new(base.applied_to(
793                            id_gen,
794                            get_outer.clone(),
795                            col_map,
796                            cte_map,
797                            context,
798                        )?),
799                        inputs: inputs
800                            .into_iter()
801                            .map(|input| {
802                                input.applied_to(
803                                    id_gen,
804                                    get_outer.clone(),
805                                    col_map,
806                                    cte_map,
807                                    context,
808                                )
809                            })
810                            .collect::<Result<Vec<_>, _>>()?,
811                    }
812                }
813                Reduce {
814                    input,
815                    group_key,
816                    aggregates,
817                    expected_group_size,
818                } => {
819                    // Reduce may contain expressions with correlated subqueries.
820                    // In addition, here an empty reduction key signifies that we need to supply default values
821                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
822                    let mut input =
823                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
824                    let applied_group_key = (0..get_outer.arity())
825                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
826                        .collect();
827                    let applied_aggregates = aggregates
828                        .into_iter()
829                        .map(|aggregate| {
830                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
831                        })
832                        .collect::<Result<Vec<_>, _>>()?;
833                    let input_type = input.typ();
834                    let default = applied_aggregates
835                        .iter()
836                        .map(|agg| {
837                            (
838                                agg.func.default(),
839                                agg.typ(&input_type.column_types).scalar_type,
840                            )
841                        })
842                        .collect();
843                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
844                    let mut reduced =
845                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);
846
847                    // Introduce default values in the case the group key is empty.
848                    if group_key.is_empty() {
849                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
850                    }
851                    reduced
852                }
853                Distinct { input } => {
854                    // Distinct is uncomplicated.
855                    input
856                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
857                        .distinct()
858                }
859                TopK {
860                    input,
861                    group_key,
862                    order_key,
863                    limit,
864                    offset,
865                    expected_group_size,
866                } => {
867                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
868                    let mut input =
869                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
870                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
871                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
872                        .collect();
873                    let applied_order_key = order_key
874                        .iter()
875                        .map(|column_order| ColumnOrder {
876                            column: column_order.column + get_outer.arity(),
877                            desc: column_order.desc,
878                            nulls_last: column_order.nulls_last,
879                        })
880                        .collect();
881
882                    let old_arity = input.arity();
883
884                    // Lower `limit`, which may introduce new columns if is a correlated subquery.
885                    let mut limit_mir = None;
886                    if let Some(limit) = limit {
887                        limit_mir = Some(
888                            limit
889                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
890                        );
891                    }
892
893                    let new_arity = input.arity();
894                    // Extend the key to contain any new columns.
895                    applied_group_key.extend(old_arity..new_arity);
896
897                    let offset = offset
898                        .try_into_literal_int64()
899                        .expect("Should be a Literal by this time")
900                        .try_into()
901                        .expect("Should have checked non-negativity of OFFSET clause already");
902                    let mut result = input.top_k(
903                        applied_group_key,
904                        applied_order_key,
905                        limit_mir,
906                        offset,
907                        expected_group_size,
908                    );
909
910                    // If new columns were added for `limit` we must remove them.
911                    if old_arity != new_arity {
912                        result = result.project((0..old_arity).collect());
913                    }
914
915                    result
916                }
917                Negate { input } => {
918                    // Negate is uncomplicated.
919                    input
920                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
921                        .negate()
922                }
923                Threshold { input } => {
924                    // Threshold is uncomplicated.
925                    input
926                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
927                        .threshold()
928                }
929            })
930        })
931    }
932}
933
934impl HirScalarExpr {
935    /// Rewrite `self` into a `mz_expr::ScalarExpr` which can be applied to the modified `inner`.
936    ///
937    /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
938    /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
939    /// The most complicated logic is for the scalar expressions that involve subqueries, each of which are
940    /// documented in more detail closer to their logic.
941    ///
942    /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
943    /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
944    /// each of these references can be found.
945    fn applied_to(
946        self,
947        id_gen: &mut mz_ore::id_gen::IdGen,
948        col_map: &ColumnMap,
949        cte_map: &mut CteMap,
950        inner: &mut MirRelationExpr,
951        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
952        context: &Context,
953    ) -> Result<MirScalarExpr, PlanError> {
954        maybe_grow(|| {
955            use MirScalarExpr as SS;
956
957            use HirScalarExpr::*;
958
959            if let Some(subquery_map) = subquery_map {
960                if let Some(col) = subquery_map.get(&self) {
961                    return Ok(SS::column(*col));
962                }
963            }
964
965            Ok::<MirScalarExpr, PlanError>(match self {
966                Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
967                Literal(row, typ, _name) => SS::Literal(Ok(row), typ),
968                Parameter(_, _name) => {
969                    panic!("cannot decorrelate expression with unbound parameters")
970                }
971                CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
972                CallUnary {
973                    func,
974                    expr,
975                    name: _,
976                } => SS::CallUnary {
977                    func,
978                    expr: Box::new(expr.applied_to(
979                        id_gen,
980                        col_map,
981                        cte_map,
982                        inner,
983                        subquery_map,
984                        context,
985                    )?),
986                },
987                CallBinary {
988                    func,
989                    expr1,
990                    expr2,
991                    name: _,
992                } => SS::CallBinary {
993                    func,
994                    expr1: Box::new(expr1.applied_to(
995                        id_gen,
996                        col_map,
997                        cte_map,
998                        inner,
999                        subquery_map,
1000                        context,
1001                    )?),
1002                    expr2: Box::new(expr2.applied_to(
1003                        id_gen,
1004                        col_map,
1005                        cte_map,
1006                        inner,
1007                        subquery_map,
1008                        context,
1009                    )?),
1010                },
1011                CallVariadic {
1012                    func,
1013                    exprs,
1014                    name: _,
1015                } => SS::CallVariadic {
1016                    func,
1017                    exprs: exprs
1018                        .into_iter()
1019                        .map(|expr| {
1020                            expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
1021                        })
1022                        .collect::<Result<Vec<_>, _>>()?,
1023                },
1024                If {
1025                    cond,
1026                    then,
1027                    els,
1028                    name,
1029                } => {
1030                    // The `If` case is complicated by the fact that we do not want to
1031                    // apply the `then` or `else` logic to tuples that respectively do
1032                    // not or do pass the `cond` test. Our strategy is to independently
1033                    // decorrelate the `then` and `else` logic, and apply each to tuples
1034                    // that respectively pass and do not pass the `cond` logic (which is
1035                    // executed, and so decorrelated, for all tuples).
1036                    //
1037                    // Informally, we turn the `if` statement into:
1038                    //
1039                    //   let then_case = inner.filter(cond).map(then);
1040                    //   let else_case = inner.filter(!cond).map(else);
1041                    //   return then_case.concat(else_case);
1042                    //
1043                    // We only require this if either expression would result in any
1044                    // computation beyond the expr itself, which we will interpret as
1045                    // "introduces additional columns". In the absence of correlation,
1046                    // we should just retain a `ScalarExpr::If` expression; the inverse
1047                    // transformation as above is complicated to recover after the fact,
1048                    // and we would benefit from not introducing the complexity.
1049
1050                    let inner_arity = inner.arity();
1051                    let cond_expr =
1052                        cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1053
1054                    // Defensive copies, in case we mangle these in decorrelation.
1055                    let inner_clone = inner.clone();
1056                    let then_clone = then.clone();
1057                    let else_clone = els.clone();
1058
1059                    let cond_arity = inner.arity();
1060                    let then_expr =
1061                        then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1062                    let else_expr =
1063                        els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1064
1065                    if cond_arity == inner.arity() {
1066                        // If no additional columns were added, we simply return the
1067                        // `If` variant with the updated expressions.
1068                        SS::If {
1069                            cond: Box::new(cond_expr),
1070                            then: Box::new(then_expr),
1071                            els: Box::new(else_expr),
1072                        }
1073                    } else {
1074                        // If columns were added, we need a more careful approach, as
1075                        // described above. First, we need to de-correlate each of
1076                        // the two expressions independently, and apply their cases
1077                        // as `MirRelationExpr::Map` operations.
1078
1079                        *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
1080                            // Restrict to records satisfying `cond_expr` and apply `then` as a map.
1081                            let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
1082                            let then_expr = then_clone.applied_to(
1083                                id_gen,
1084                                col_map,
1085                                cte_map,
1086                                &mut then_inner,
1087                                subquery_map,
1088                                context,
1089                            )?;
1090                            let then_arity = then_inner.arity();
1091                            then_inner = then_inner
1092                                .map_one(then_expr)
1093                                .project((0..inner_arity).chain(Some(then_arity)).collect());
1094
1095                            // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
1096                            let mut else_inner = get_inner.filter(vec![SS::CallVariadic {
1097                                func: mz_expr::VariadicFunc::Or,
1098                                exprs: vec![
1099                                    cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
1100                                    cond_expr.clone().call_is_null(),
1101                                ],
1102                            }]);
1103                            let else_expr = else_clone.applied_to(
1104                                id_gen,
1105                                col_map,
1106                                cte_map,
1107                                &mut else_inner,
1108                                subquery_map,
1109                                context,
1110                            )?;
1111                            let else_arity = else_inner.arity();
1112                            else_inner = else_inner
1113                                .map_one(else_expr)
1114                                .project((0..inner_arity).chain(Some(else_arity)).collect());
1115
1116                            // concatenate the two results.
1117                            Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
1118                        })?;
1119
1120                        SS::Column(inner_arity, name)
1121                    }
1122                }
1123
1124                // Subqueries!
1125                // These are surprisingly subtle. Things to be careful of:
1126
1127                // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
1128                // * change the row counts of the outer query
1129                // * accidentally compute its own value using the row counts of the outer query
1130                // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
1131                // query and then join the answers back on to the original rows of the outer query.
1132
1133                // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
1134                // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
1135                Exists(expr, name) => {
1136                    let apply_requires_distinct_outer = true;
1137                    *inner = apply_existential_subquery(
1138                        id_gen,
1139                        inner.take_dangerous(),
1140                        col_map,
1141                        cte_map,
1142                        *expr,
1143                        apply_requires_distinct_outer,
1144                        context,
1145                    )?;
1146                    SS::Column(inner.arity() - 1, name)
1147                }
1148
1149                Select(expr, name) => {
1150                    let apply_requires_distinct_outer = true;
1151                    *inner = apply_scalar_subquery(
1152                        id_gen,
1153                        inner.take_dangerous(),
1154                        col_map,
1155                        cte_map,
1156                        *expr,
1157                        apply_requires_distinct_outer,
1158                        context,
1159                    )?;
1160                    SS::Column(inner.arity() - 1, name)
1161                }
1162                Windowing(expr, _name) => {
1163                    let partition_by = expr.partition_by;
1164                    let order_by = expr.order_by;
1165
1166                    // argument lowering for scalar window functions
1167                    // (We need to specify the & _ in the arguments because of this problem:
1168                    // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
1169                    let scalar_lower_args =
1170                        |_id_gen: &mut _,
1171                         _col_map: &_,
1172                         _cte_map: &mut _,
1173                         _get_inner: &mut _,
1174                         _subquery_map: &Option<&_>,
1175                         order_by_mir: Vec<MirScalarExpr>,
1176                         original_row_record,
1177                         original_row_record_type: SqlScalarType| {
1178                            let agg_input = MirScalarExpr::CallVariadic {
1179                                func: mz_expr::VariadicFunc::ListCreate {
1180                                    elem_type: original_row_record_type.clone(),
1181                                },
1182                                exprs: vec![original_row_record],
1183                            };
1184                            let mut agg_input = vec![agg_input];
1185                            agg_input.extend(order_by_mir.clone());
1186                            let agg_input = MirScalarExpr::CallVariadic {
1187                                func: mz_expr::VariadicFunc::RecordCreate {
1188                                    field_names: (0..agg_input.len())
1189                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1190                                        .collect_vec(),
1191                                },
1192                                exprs: agg_input,
1193                            };
1194                            let list_type = SqlScalarType::List {
1195                                element_type: Box::new(original_row_record_type),
1196                                custom_id: None,
1197                            };
1198                            let agg_input_type = SqlScalarType::Record {
1199                                fields: std::iter::once(&list_type)
1200                                    .map(|t| {
1201                                        (
1202                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1203                                            t.clone().nullable(false),
1204                                        )
1205                                    })
1206                                    .collect(),
1207                                custom_id: None,
1208                            }
1209                            .nullable(false);
1210
1211                            Ok((agg_input, agg_input_type))
1212                        };
1213
1214                    // argument lowering for value window functions and aggregate window functions
1215                    let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
1216                        |id_gen: &mut _,
1217                         col_map: &_,
1218                         cte_map: &mut _,
1219                         get_inner: &mut _,
1220                         subquery_map: &Option<&_>,
1221                         order_by_mir: Vec<MirScalarExpr>,
1222                         original_row_record,
1223                         original_row_record_type| {
1224                            // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]
1225
1226                            // Compute the encoded args for all rows
1227                            let mir_encoded_args = hir_encoded_args.applied_to(
1228                                id_gen,
1229                                col_map,
1230                                cte_map,
1231                                get_inner,
1232                                subquery_map,
1233                                context,
1234                            )?;
1235                            let mir_encoded_args_type = mir_encoded_args
1236                                .typ(&get_inner.typ().column_types)
1237                                .scalar_type;
1238
1239                            // Build a new record that has two fields:
1240                            // 1. the original row in a record
1241                            // 2. the encoded args (which can be either a single value, or a record
1242                            //    if the window function has multiple arguments, such as `lag`)
1243                            let fn_input_record_fields: Box<[_]> =
1244                                [original_row_record_type, mir_encoded_args_type]
1245                                    .iter()
1246                                    .map(|t| {
1247                                        (
1248                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1249                                            t.clone().nullable(false),
1250                                        )
1251                                    })
1252                                    .collect();
1253                            let fn_input_record = MirScalarExpr::CallVariadic {
1254                                func: mz_expr::VariadicFunc::RecordCreate {
1255                                    field_names: fn_input_record_fields
1256                                        .iter()
1257                                        .map(|(n, _)| n.clone())
1258                                        .collect_vec(),
1259                                },
1260                                exprs: vec![original_row_record, mir_encoded_args],
1261                            };
1262                            let fn_input_record_type = SqlScalarType::Record {
1263                                fields: fn_input_record_fields,
1264                                custom_id: None,
1265                            }
1266                            .nullable(false);
1267
1268                            // Build a new record with the record above + the ORDER BY exprs
1269                            // This follows the standard encoding of ORDER BY exprs used by aggregate functions
1270                            let mut agg_input = vec![fn_input_record];
1271                            agg_input.extend(order_by_mir.clone());
1272                            let agg_input = MirScalarExpr::CallVariadic {
1273                                func: mz_expr::VariadicFunc::RecordCreate {
1274                                    field_names: (0..agg_input.len())
1275                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1276                                        .collect_vec(),
1277                                },
1278                                exprs: agg_input,
1279                            };
1280
1281                            let agg_input_type = SqlScalarType::Record {
1282                                fields: [(
1283                                    ColumnName::from(UNKNOWN_COLUMN_NAME),
1284                                    fn_input_record_type.nullable(false),
1285                                )]
1286                                .into(),
1287                                custom_id: None,
1288                            }
1289                            .nullable(false);
1290
1291                            Ok((agg_input, agg_input_type))
1292                        }
1293                    };
1294
1295                    match expr.func {
1296                        WindowExprType::Scalar(scalar_window_expr) => {
1297                            let mir_aggr_func = scalar_window_expr.into_expr();
1298                            Self::window_func_applied_to(
1299                                id_gen,
1300                                col_map,
1301                                cte_map,
1302                                inner,
1303                                subquery_map,
1304                                partition_by,
1305                                order_by,
1306                                mir_aggr_func,
1307                                scalar_lower_args,
1308                                context,
1309                            )?
1310                        }
1311                        WindowExprType::Value(value_window_expr) => {
1312                            let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();
1313
1314                            Self::window_func_applied_to(
1315                                id_gen,
1316                                col_map,
1317                                cte_map,
1318                                inner,
1319                                subquery_map,
1320                                partition_by,
1321                                order_by,
1322                                mir_aggr_func,
1323                                value_or_aggr_lower_args(hir_encoded_args),
1324                                context,
1325                            )?
1326                        }
1327                        WindowExprType::Aggregate(aggr_window_expr) => {
1328                            let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();
1329
1330                            Self::window_func_applied_to(
1331                                id_gen,
1332                                col_map,
1333                                cte_map,
1334                                inner,
1335                                subquery_map,
1336                                partition_by,
1337                                order_by,
1338                                mir_aggr_func,
1339                                value_or_aggr_lower_args(hir_encoded_args),
1340                                context,
1341                            )?
1342                        }
1343                    }
1344                }
1345            })
1346        })
1347    }
1348
1349    fn window_func_applied_to<F>(
1350        id_gen: &mut mz_ore::id_gen::IdGen,
1351        col_map: &ColumnMap,
1352        cte_map: &mut CteMap,
1353        inner: &mut MirRelationExpr,
1354        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
1355        partition_by: Vec<HirScalarExpr>,
1356        order_by: Vec<HirScalarExpr>,
1357        mir_aggr_func: AggregateFunc,
1358        lower_args: F,
1359        context: &Context,
1360    ) -> Result<MirScalarExpr, PlanError>
1361    where
1362        F: FnOnce(
1363            &mut mz_ore::id_gen::IdGen,
1364            &ColumnMap,
1365            &mut CteMap,
1366            &mut MirRelationExpr,
1367            &Option<&BTreeMap<HirScalarExpr, usize>>,
1368            Vec<MirScalarExpr>,
1369            MirScalarExpr,
1370            SqlScalarType,
1371        ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
1372    {
1373        // Example MIRs for a window function (specifically, a window aggregation):
1374        //
1375        // CREATE TABLE t7(x INT, y INT);
1376        //
1377        // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1378        //
1379        // Decorrelated Plan
1380        // Project (#3)
1381        //   Map (#2)
1382        //     Project (#3..=#5)
1383        //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
1384        //         FlatMap unnest_list(#1)
1385        //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1386        //             Map ((#0 + #1))
1387        //               CrossJoin
1388        //                 Constant
1389        //                   - ()
1390        //                 Get materialize.public.t7
1391        //
1392        // The same query after optimizations:
1393        //
1394        // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
1395        //
1396        // Optimized Plan
1397        // Explained Query:
1398        //   Project (#2)
1399        //     Map (record_get[0](#1))
1400        //       FlatMap unnest_list(#0)
1401        //         Project (#1)
1402        //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
1403        //             ReadStorage materialize.public.t7
1404        //
1405        // The `row(row(row(...), ...), ...)` stuff means the following:
1406        // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
1407        //   - The <arguments to window function> can be either a single value or itself a
1408        //     `row` if there are multiple arguments.
1409        //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
1410        //     ORDER BY columns.
1411        //   - The <original row> currently always captures the entire original row. This should
1412        //     improve when we make `ProjectionPushdown` smarter, see
1413        //     https://github.com/MaterializeInc/database-issues/issues/5090
1414        //
1415        // TODO:
1416        // We should probably introduce some dedicated Datum constructor functions instead of `row`
1417        // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
1418        // might even introduce dedicated Datum enum variants, so that the rendering code also
1419        // becomes more readable (and possibly slightly more performant).
1420
1421        *inner = inner
1422            .take_dangerous()
1423            .let_in(id_gen, |id_gen, mut get_inner| {
1424                let order_by_mir = order_by
1425                    .into_iter()
1426                    .map(|o| {
1427                        o.applied_to(
1428                            id_gen,
1429                            col_map,
1430                            cte_map,
1431                            &mut get_inner,
1432                            subquery_map,
1433                            context,
1434                        )
1435                    })
1436                    .collect::<Result<Vec<_>, _>>()?;
1437
1438                // Record input arity here so that any group_keys that need to mutate get_inner
1439                // don't add those columns to the aggregate input.
1440                let input_type = get_inner.typ();
1441                let input_arity = input_type.arity();
1442                // The reduction that computes the window function must be keyed on the columns
1443                // from the outer context, plus the expressions in the partition key. The current
1444                // subquery will be 'executed' for every distinct row from the outer context so
1445                // by putting the outer columns in the grouping key we isolate each re-execution.
1446                let mut group_key = col_map
1447                    .inner
1448                    .iter()
1449                    .map(|(_, outer_col)| *outer_col)
1450                    .sorted()
1451                    .collect_vec();
1452                for p in partition_by {
1453                    let key = p.applied_to(
1454                        id_gen,
1455                        col_map,
1456                        cte_map,
1457                        &mut get_inner,
1458                        subquery_map,
1459                        context,
1460                    )?;
1461                    if let MirScalarExpr::Column(c, _name) = key {
1462                        group_key.push(c);
1463                    } else {
1464                        get_inner = get_inner.map_one(key);
1465                        group_key.push(get_inner.arity() - 1);
1466                    }
1467                }
1468
1469                get_inner.let_in(id_gen, |id_gen, mut get_inner| {
1470                    // Original columns of the relation
1471                    let fields: Box<_> = input_type
1472                        .column_types
1473                        .iter()
1474                        .take(input_arity)
1475                        .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
1476                        .collect();
1477
1478                    // Original row made into a record
1479                    let original_row_record = MirScalarExpr::CallVariadic {
1480                        func: mz_expr::VariadicFunc::RecordCreate {
1481                            field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
1482                        },
1483                        exprs: (0..input_arity).map(MirScalarExpr::column).collect_vec(),
1484                    };
1485                    let original_row_record_type = SqlScalarType::Record {
1486                        fields,
1487                        custom_id: None,
1488                    };
1489
1490                    let (agg_input, agg_input_type) = lower_args(
1491                        id_gen,
1492                        col_map,
1493                        cte_map,
1494                        &mut get_inner,
1495                        subquery_map,
1496                        order_by_mir,
1497                        original_row_record,
1498                        original_row_record_type,
1499                    )?;
1500
1501                    let aggregate = mz_expr::AggregateExpr {
1502                        func: mir_aggr_func,
1503                        expr: agg_input,
1504                        distinct: false,
1505                    };
1506
1507                    // Actually call reduce with the window function
1508                    // The output of the aggregation function should be a list of tuples that has
1509                    // the result in the first position, and the original row in the second position
1510                    let mut reduce = get_inner
1511                        .reduce(group_key.clone(), vec![aggregate.clone()], None)
1512                        .flat_map(
1513                            mz_expr::TableFunc::UnnestList {
1514                                el_typ: aggregate
1515                                    .func
1516                                    .output_type(agg_input_type)
1517                                    .scalar_type
1518                                    .unwrap_list_element_type()
1519                                    .clone(),
1520                            },
1521                            vec![MirScalarExpr::column(group_key.len())],
1522                        );
1523                    let record_col = reduce.arity() - 1;
1524
1525                    // Unpack the record output by the window function
1526                    for c in 0..input_arity {
1527                        reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1528                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
1529                            expr: Box::new(MirScalarExpr::CallUnary {
1530                                func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
1531                                expr: Box::new(MirScalarExpr::column(record_col)),
1532                            }),
1533                        });
1534                    }
1535
1536                    // Append the column with the result of the window function.
1537                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
1538                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
1539                        expr: Box::new(MirScalarExpr::column(record_col)),
1540                    });
1541
1542                    let agg_col = record_col + 1 + input_arity;
1543                    Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
1544                })
1545            })?;
1546        Ok(MirScalarExpr::column(inner.arity() - 1))
1547    }
1548
1549    /// Applies the subqueries in the given list of scalar expressions to every distinct
1550    /// value of the given relation and returns a join of the given relation with all
1551    /// the subqueries found, and the mapping of scalar expressions with columns projected
1552    /// by the returned join that will hold their results.
1553    fn lower_subqueries(
1554        exprs: &[Self],
1555        id_gen: &mut mz_ore::id_gen::IdGen,
1556        col_map: &ColumnMap,
1557        cte_map: &mut CteMap,
1558        inner: MirRelationExpr,
1559        context: &Context,
1560    ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1561        let mut subquery_map = BTreeMap::new();
1562        let output = inner.let_in(id_gen, |id_gen, get_inner| {
1563            let mut subqueries = Vec::new();
1564            let distinct_inner = get_inner.clone().distinct();
1565            for expr in exprs.iter() {
1566                expr.visit_pre_post(
1567                    &mut |e| match e {
1568                        // For simplicity, subqueries within a conditional statement will be
1569                        // lowered when lowering the conditional expression.
1570                        HirScalarExpr::If { .. } => Some(vec![]),
1571                        _ => None,
1572                    },
1573                    &mut |e| match e {
1574                        HirScalarExpr::Select(expr, _name) => {
1575                            let apply_requires_distinct_outer = false;
1576                            let subquery = apply_scalar_subquery(
1577                                id_gen,
1578                                distinct_inner.clone(),
1579                                col_map,
1580                                cte_map,
1581                                (**expr).clone(),
1582                                apply_requires_distinct_outer,
1583                                context,
1584                            )
1585                            .unwrap();
1586
1587                            subqueries.push((e.clone(), subquery));
1588                        }
1589                        HirScalarExpr::Exists(expr, _name) => {
1590                            let apply_requires_distinct_outer = false;
1591                            let subquery = apply_existential_subquery(
1592                                id_gen,
1593                                distinct_inner.clone(),
1594                                col_map,
1595                                cte_map,
1596                                (**expr).clone(),
1597                                apply_requires_distinct_outer,
1598                                context,
1599                            )
1600                            .unwrap();
1601                            subqueries.push((e.clone(), subquery));
1602                        }
1603                        _ => {}
1604                    },
1605                )?;
1606            }
1607
1608            if subqueries.is_empty() {
1609                Ok::<MirRelationExpr, PlanError>(get_inner)
1610            } else {
1611                let inner_arity = get_inner.arity();
1612                let mut total_arity = inner_arity;
1613                let mut join_inputs = vec![get_inner];
1614                let mut join_input_arities = vec![inner_arity];
1615                for (expr, subquery) in subqueries.into_iter() {
1616                    // Avoid lowering duplicated subqueries
1617                    if !subquery_map.contains_key(&expr) {
1618                        let subquery_arity = subquery.arity();
1619                        assert_eq!(subquery_arity, inner_arity + 1);
1620                        join_inputs.push(subquery);
1621                        join_input_arities.push(subquery_arity);
1622                        total_arity += subquery_arity;
1623
1624                        // Column with the value of the subquery
1625                        subquery_map.insert(expr, total_arity - 1);
1626                    }
1627                }
1628                // Each subquery projects all the columns of the outer context (distinct_inner)
1629                // plus 1 column, containing the result of the subquery. Those columns must be
1630                // joined with the outer/main relation (get_inner).
1631                let input_mapper =
1632                    mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1633                let equivalences = (0..inner_arity)
1634                    .map(|col| {
1635                        join_inputs
1636                            .iter()
1637                            .enumerate()
1638                            .map(|(input, _)| {
1639                                MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1640                            })
1641                            .collect_vec()
1642                    })
1643                    .collect_vec();
1644                Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1645            }
1646        })?;
1647        Ok((output, subquery_map))
1648    }
1649
1650    /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1651    ///
1652    /// Returns an _internal_ error if the expression contains
1653    /// - a subquery
1654    /// - a column reference to an outer level
1655    /// - a parameter
1656    /// - a window function call
1657    ///
1658    /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1659    pub fn lower_uncorrelated(self) -> Result<MirScalarExpr, PlanError> {
1660        use MirScalarExpr as SS;
1661
1662        use HirScalarExpr::*;
1663
1664        Ok(match self {
1665            Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1666            Literal(datum, typ, _name) => SS::Literal(Ok(datum), typ),
1667            CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1668            CallUnary {
1669                func,
1670                expr,
1671                name: _,
1672            } => SS::CallUnary {
1673                func,
1674                expr: Box::new(expr.lower_uncorrelated()?),
1675            },
1676            CallBinary {
1677                func,
1678                expr1,
1679                expr2,
1680                name: _,
1681            } => SS::CallBinary {
1682                func,
1683                expr1: Box::new(expr1.lower_uncorrelated()?),
1684                expr2: Box::new(expr2.lower_uncorrelated()?),
1685            },
1686            CallVariadic {
1687                func,
1688                exprs,
1689                name: _,
1690            } => SS::CallVariadic {
1691                func,
1692                exprs: exprs
1693                    .into_iter()
1694                    .map(|expr| expr.lower_uncorrelated())
1695                    .collect::<Result<_, _>>()?,
1696            },
1697            If {
1698                cond,
1699                then,
1700                els,
1701                name: _,
1702            } => SS::If {
1703                cond: Box::new(cond.lower_uncorrelated()?),
1704                then: Box::new(then.lower_uncorrelated()?),
1705                els: Box::new(els.lower_uncorrelated()?),
1706            },
1707            Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1708                sql_bail!(
1709                    "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1710                    self
1711                );
1712            }
1713        })
1714    }
1715}
1716
/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
/// will, in effect, be executed once for every distinct row in `outer`, and the
/// results will be joined with `outer`. Note that columns in `outer` that are
/// not depended upon by `inner` are thrown away before the distinct, so that we
/// don't perform needless computation of `inner`.
///
/// `branch` will inspect the contents of `inner` to determine whether `inner`
/// is not multiplicity sensitive (roughly, contains only maps, filters,
/// projections, and calls to table functions). If it is not multiplicity
/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
/// operations that were not present in `inner`, then set
/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
/// that distinctifies `outer`.
///
/// The caller must supply the `apply` function that applies the rewritten
/// `inner` to `outer`. `apply` receives the id generator, `inner`, the relation
/// to apply it to, the column map for resolving outer column references within
/// that relation, the CTE map, and the context.
///
/// # Returns
///
/// - On the simple path: whatever `apply` returns when given the `let_in`-bound
///   `outer` itself.
/// - Otherwise: `apply` is given the distinct key columns of `outer`, or a
///   single empty row if `inner` references no outer columns. Its result is
///   inner-joined back to `outer` on those key columns. The output has all
///   columns of `outer`, followed by the columns of `apply`'s result after its
///   leading `key.len()` columns, which repeat the key.
///
/// # Errors
///
/// Propagates any error returned by `apply`.
fn branch<F>(
    id_gen: &mut mz_ore::id_gen::IdGen,
    outer: MirRelationExpr,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: HirRelationExpr,
    apply_requires_distinct_outer: bool,
    context: &Context,
    apply: F,
) -> Result<MirRelationExpr, PlanError>
where
    F: FnOnce(
        &mut mz_ore::id_gen::IdGen,
        HirRelationExpr,
        MirRelationExpr,
        &ColumnMap,
        &mut CteMap,
        &Context,
    ) -> Result<MirRelationExpr, PlanError>,
{
    // TODO: It would be nice to have a version of this code w/o optimizations,
    // at the least for purposes of understanding. It was difficult for one reader
    // to understand the required properties of `outer` and `col_map`.

    // If the inner expression is sufficiently simple, it is safe to apply it
    // *directly* to outer, rather than applying it to the distinctified key
    // (see below).
    //
    // As an example, consider the following query:
    //
    //     CREATE TABLE t (a int, b int);
    //     SELECT a, series FROM t, generate_series(1, t.b) series;
    //
    // The "simple" path for the `SELECT` yields
    //
    //     %0 =
    //     | Get t
    //     | FlatMap generate_series(1, #1)
    //
    // while the non-simple path yields:
    //
    //    %0 =
    //    | Get t
    //
    //    %1 =
    //    | Get t
    //    | Distinct group=(#1)
    //    | FlatMap generate_series(1, #0)
    //
    //    %2 =
    //    | Join %0 %1 (= #1 #2)
    //
    // There is a tradeoff here: the simple plan is stateless, but the non-
    // simple plan may do (much) less computation if there are only a few
    // distinct values of `t.b`.
    //
    // We apply a very simple heuristic here and take the simple path if `inner`
    // contains only maps, filters, projections, and calls to table functions.
    // The intuition is that straightforward usage of table functions should
    // take the simple path, while everything else should not. (In theory we
    // think this transformation is valid as long as `inner` does not contain a
    // Reduce, Distinct, or TopK node, but it is not always an optimization in
    // the general case.)
    //
    // TODO(benesch): this should all be handled by a proper optimizer, but
    // detecting the moment of decorrelation in the optimizer right now is too
    // hard.
    let mut is_simple = true;
    // Any node kind other than the ones listed below (`Constant` is also
    // allowed) marks `inner` as not simple.
    #[allow(deprecated)]
    inner.visit(0, &mut |expr, _| match expr {
        HirRelationExpr::Constant { .. }
        | HirRelationExpr::Project { .. }
        | HirRelationExpr::Map { .. }
        | HirRelationExpr::Filter { .. }
        | HirRelationExpr::CallTable { .. } => (),
        _ => is_simple = false,
    });
    if is_simple && !apply_requires_distinct_outer {
        // `apply` is handed all of `outer`. `enter_scope` receives the number of
        // `outer` columns beyond those already tracked by `col_map`; presumably
        // those become the columns of the new scope seen by `inner`.
        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
        return outer.let_in(id_gen, |id_gen, get_outer| {
            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
        });
    }

    // The key consists of the columns from the outer expression upon which the
    // inner relation depends. We discover these dependencies by walking the
    // inner relation expression and looking for column references whose level
    // escapes inner.
    //
    // At the end of this process, `key` contains the decorrelated position of
    // each outer column, according to the passed-in `col_map`, and
    // `new_col_map` maps each outer column to its new ordinal position in key.
    let mut outer_cols = BTreeSet::new();
    #[allow(deprecated)]
    inner.visit_columns(0, &mut |depth, col| {
        // Test if the column reference escapes the subquery.
        if col.level > depth {
            outer_cols.insert(ColumnRef {
                level: col.level - depth,
                column: col.column,
            });
        }
    });
    // Collect all the outer columns referenced by any CTE referenced by
    // the inner relation.
    #[allow(deprecated)]
    inner.visit(0, &mut |e, _| match e {
        HirRelationExpr::Get {
            id: mz_expr::Id::Local(id),
            ..
        } => {
            if let Some(cte_desc) = cte_map.get(id) {
                let cte_outer_arity = cte_desc.outer_relation.arity();
                outer_cols.extend(
                    col_map
                        .inner
                        .iter()
                        .filter(|(_, position)| **position < cte_outer_arity)
                        .map(|(c, _)| {
                            // `col_map` maps column references to column positions in
                            // `outer`'s projection.
                            // `outer_cols` is meant to contain the external column
                            // references in `inner`.
                            // Since `inner` defines a new scope, any column reference
                            // in `col_map` is one level deeper when seen from within
                            // `inner`, hence the +1.
                            ColumnRef {
                                level: c.level + 1,
                                column: c.column,
                            }
                        }),
                );
            }
        }
        HirRelationExpr::Let { id, .. } => {
            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
            // we would need to remove the old CTE with the same ID temporarily while
            // traversing the definition of the new CTE under the same ID.
            assert!(!cte_map.contains_key(id));
        }
        _ => {}
    });
    // `outer_cols` is a `BTreeSet`, so key columns are assigned positions in
    // sorted `ColumnRef` order.
    let mut new_col_map = BTreeMap::new();
    let mut key = vec![];
    for col in outer_cols {
        new_col_map.insert(col, key.len());
        key.push(col_map.get(&ColumnRef {
            // Note: `outer_cols` contains the external column references within `inner`.
            // We must compensate for `inner`'s scope when translating column references
            // as seen within `inner` to column references as seen from `outer`'s context,
            // hence the -1.
            level: col.level - 1,
            column: col.column,
        }));
    }
    let new_col_map = ColumnMap::new(new_col_map);
    outer.let_in(id_gen, |id_gen, get_outer| {
        let keyed_outer = if key.is_empty() {
            // Don't depend on outer at all if the branch is not correlated,
            // which yields vastly better query plans. Note that this is a bit
            // weird in that the branch will be computed even if outer has no
            // rows, whereas if it had been correlated it would not (and *could*
            // not) have been computed if outer had no rows, but the callers of
            // this function don't mind these somewhat-weird semantics.
            // (A constant relation with exactly one row of zero columns.)
            MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![]))
        } else {
            get_outer.clone().distinct_by(key.clone())
        };
        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
            // `oa`: arity of `outer`; `ba`: arity of the applied branch.
            let oa = get_outer.arity();
            let branch = apply(
                id_gen,
                inner,
                get_keyed_outer,
                &new_col_map,
                cte_map,
                context,
            )?;
            let ba = branch.arity();
            // Inner join: key column `k` of `outer` (input 0) must equal column
            // `i` of the branch (input 1), for the `i`-th key column.
            let joined = MirRelationExpr::join(
                vec![get_outer.clone(), branch],
                key.iter()
                    .enumerate()
                    .map(|(i, &k)| vec![(0, k), (1, i)])
                    .collect(),
            )
            // throw away the right-hand copy of the key we just joined on
            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
            Ok(joined)
        })
    })
}
1927
1928fn apply_scalar_subquery(
1929    id_gen: &mut mz_ore::id_gen::IdGen,
1930    outer: MirRelationExpr,
1931    col_map: &ColumnMap,
1932    cte_map: &mut CteMap,
1933    scalar_subquery: HirRelationExpr,
1934    apply_requires_distinct_outer: bool,
1935    context: &Context,
1936) -> Result<MirRelationExpr, PlanError> {
1937    branch(
1938        id_gen,
1939        outer,
1940        col_map,
1941        cte_map,
1942        scalar_subquery,
1943        apply_requires_distinct_outer,
1944        context,
1945        |id_gen, expr, get_inner, col_map, cte_map, context| {
1946            // compute for every row in get_inner
1947            let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1948            let col_type = select.typ().column_types.into_last();
1949
1950            let inner_arity = get_inner.arity();
1951            // We must determine a count for each `get_inner` prefix,
1952            // and report an error if that count exceeds one.
1953            let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1954                // Count for each `get_inner` prefix.
1955                let counts = get_select.clone().reduce(
1956                    (0..inner_arity).collect::<Vec<_>>(),
1957                    vec![mz_expr::AggregateExpr {
1958                        func: mz_expr::AggregateFunc::Count,
1959                        expr: MirScalarExpr::literal_true(),
1960                        distinct: false,
1961                    }],
1962                    None,
1963                );
1964
1965                let use_guard = context.config.enable_guard_subquery_tablefunc;
1966
1967                // Errors should result from counts > 1.
1968                let errors = if use_guard {
1969                    counts
1970                        .flat_map(
1971                            mz_expr::TableFunc::GuardSubquerySize {
1972                                column_type: col_type.clone().scalar_type,
1973                            },
1974                            vec![MirScalarExpr::column(inner_arity)],
1975                        )
1976                        .project(
1977                            (0..inner_arity)
1978                                .chain(Some(inner_arity + 1))
1979                                .collect::<Vec<_>>(),
1980                        )
1981                } else {
1982                    counts
1983                        .filter(vec![MirScalarExpr::column(inner_arity).call_binary(
1984                            MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
1985                            func::Gt,
1986                        )])
1987                        .project((0..inner_arity).collect::<Vec<_>>())
1988                        .map_one(MirScalarExpr::literal(
1989                            Err(mz_expr::EvalError::MultipleRowsFromSubquery),
1990                            col_type.clone().scalar_type,
1991                        ))
1992                };
1993                // Return `get_select` and any errors added in.
1994                Ok::<_, PlanError>(get_select.union(errors))
1995            })?;
1996            // append Null to anything that didn't return any rows
1997            let default = vec![(Datum::Null, col_type.scalar_type)];
1998            get_inner.lookup(id_gen, guarded, default)
1999        },
2000    )
2001}
2002
2003fn apply_existential_subquery(
2004    id_gen: &mut mz_ore::id_gen::IdGen,
2005    outer: MirRelationExpr,
2006    col_map: &ColumnMap,
2007    cte_map: &mut CteMap,
2008    subquery_expr: HirRelationExpr,
2009    apply_requires_distinct_outer: bool,
2010    context: &Context,
2011) -> Result<MirRelationExpr, PlanError> {
2012    branch(
2013        id_gen,
2014        outer,
2015        col_map,
2016        cte_map,
2017        subquery_expr,
2018        apply_requires_distinct_outer,
2019        context,
2020        |id_gen, expr, get_inner, col_map, cte_map, context| {
2021            let exists = expr
2022                // compute for every row in get_inner
2023                .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2024                // throw away actual values and just remember whether or not there were __any__ rows
2025                .distinct_by((0..get_inner.arity()).collect())
2026                // Append true to anything that returned any rows.
2027                .map(vec![MirScalarExpr::literal_true()]);
2028
2029            // append False to anything that didn't return any rows
2030            get_inner.lookup(id_gen, exists, vec![(Datum::False, SqlScalarType::Bool)])
2031        },
2032    )
2033}
2034
2035impl AggregateExpr {
2036    fn applied_to(
2037        self,
2038        id_gen: &mut mz_ore::id_gen::IdGen,
2039        col_map: &ColumnMap,
2040        cte_map: &mut CteMap,
2041        inner: &mut MirRelationExpr,
2042        context: &Context,
2043    ) -> Result<mz_expr::AggregateExpr, PlanError> {
2044        let AggregateExpr {
2045            func,
2046            expr,
2047            distinct,
2048        } = self;
2049
2050        Ok(mz_expr::AggregateExpr {
2051            func: func.into_expr(),
2052            expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2053            distinct,
2054        })
2055    }
2056}
2057
/// Attempts an efficient outer join, if `on` has equijoin structure.
///
/// Both `left` and `right` are decorrelated inputs.
///
/// The first `oa` columns correspond to an outer context: we should do the
/// outer join independently for each prefix. In the case that `on` contains
/// just some equality tests between columns of `left` and `right` and some
/// local predicates, we can employ a relatively simple plan.
///
/// The last `on_subquery_types.len()` columns correspond to results from
/// subqueries defined in the `on` clause - we treat those as theta-join
/// conditions that prohibit the use of the simple plan attempted here.
///
/// Returns `Ok(None)` if [`OnPredicates::is_equijoin`] rejects the classified
/// `on` predicates. Otherwise returns `Ok(Some(plan))`, where `plan` has the
/// schema `outer × left × right` (it does not include the `on` subquery
/// columns) and consists of the inner equijoin plus, depending on `kind`:
///
/// - for `LeftOuter`/`FullOuter`: the unmatched `left` rows padded with typed
///   `NULL`s for the `right` columns;
/// - for `RightOuter`/`FullOuter`: the unmatched `right` rows padded with
///   typed `NULL`s for the `left` columns.
fn attempt_outer_equijoin(
    left: MirRelationExpr,
    right: MirRelationExpr,
    on: MirScalarExpr,
    on_subquery_types: Vec<SqlColumnType>,
    kind: JoinKind,
    oa: usize,
    id_gen: &mut mz_ore::id_gen::IdGen,
    context: &Context,
) -> Result<Option<MirRelationExpr>, PlanError> {
    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
    // predicates that reference subqueries as long as these subqueries don't
    // reference `left` and `right` at the same time.
    //
    // TODO(database-issues#6828): This code can be improved as follows:
    //
    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
    // 2. Use the canonicalized `on` predicate in the non-equijoin based
    //    lowering strategy.
    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
    // 4. Pass the classified `OnPredicates` as a parameter.
    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
    //
    // Steps (1 + 2) require further investigation because we might change the
    // error semantics in case the `on` predicate contains a literal error.

    let l_type = left.typ();
    let r_type = right.typ();
    // Arities of `left` and `right` without their `outer` prefix.
    let la = l_type.column_types.len() - oa;
    let ra = r_type.column_types.len() - oa;
    let sa = on_subquery_types.len();

    // The output type contains [outer, left, right, sa] attributes.
    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
    output_type.extend(l_type.column_types);
    output_type.extend(r_type.column_types.into_iter().skip(oa));
    output_type.extend(on_subquery_types);

    // Generally healthy to do, but specifically `USING` conditions sometimes
    // put an `AND true` at the end of the `ON` condition.
    //
    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
    // However, in that case it's not clear that we won't see regressions if
    // `on` simplifies to a literal error.
    let mut on = vec![on];
    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);

    // Form the left and right types without the outer attributes.
    output_type.drain(0..oa);
    let lt = output_type.drain(0..la).collect_vec();
    let rt = output_type.drain(0..ra).collect_vec();
    assert!(output_type.len() == sa);

    // Classify the canonicalized predicates; signal "not applicable" with
    // `Ok(None)` if they do not qualify for the equijoin-based plan.
    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
    if !on_predicates.is_equijoin(context) {
        return Ok(None);
    }

    // If we've gotten this far, we can do the clever thing.
    // We'll want to use left and right multiple times
    let result = left.let_in(id_gen, |id_gen, get_left| {
        right.let_in(id_gen, |id_gen, get_right| {
            // TODO: we know that we can re-use the arrangements of left and right
            // needed for the inner join with each of the conditional outer joins.
            // It is not clear whether we should hint that, or just let the planner
            // and optimizer run and see what happens.

            // We'll want the inner join (minus repeated columns)
            let join = MirRelationExpr::join(
                vec![get_left.clone(), get_right.clone()],
                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // remove those columns from `right` repeating the first `oa` columns.
            .project(
                (0..(oa + la))
                    .chain((oa + la + oa)..(oa + la + oa + ra))
                    .collect(),
            )
            // apply the filter constraints here, to ensure nulls are not matched.
            .filter(on);

            // We'll want to re-use the results of the join multiple times.
            join.let_in(id_gen, |id_gen, get_join| {
                let mut result = get_join.clone();

                // A collection of keys present in both left and right collections.
                let join_keys = on_predicates.join_keys();
                let both_keys_arity = join_keys.len();
                let both_keys = get_join.restrict(join_keys).distinct();

                // The plan is now to determine the left and right rows matched in the
                // inner join, subtract them from left and right respectively, pad what
                // remains with nulls, and fold them in to `result`.

                both_keys.let_in(id_gen, |_id_gen, get_both| {
                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
                        // Rows in `left` matched in the inner equijoin. This is
                        // a semi-join between `left` and `both_keys`.
                        let left_present = MirRelationExpr::join_scalars(
                            vec![
                                get_left
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.lhs()),
                                get_both.clone(),
                            ],
                            itertools::zip_eq(
                                on_predicates.eq_lhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
                            )
                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
                            .collect(),
                        )
                        .project((0..(oa + la)).collect());

                        // Determine the types of nulls to use as filler.
                        let right_fill = rt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, filled with typed nulls.
                        result = left_present
                            .negate()
                            .union(get_left.clone())
                            .map(right_fill)
                            .union(result);
                    }

                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                        // Rows in `right` matched in the inner equijoin. This
                        // is a semi-join between `right` and `both_keys`.
                        let right_present = MirRelationExpr::join_scalars(
                            vec![
                                get_right
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.rhs()),
                                get_both,
                            ],
                            itertools::zip_eq(
                                on_predicates.eq_rhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
                            )
                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
                            .collect(),
                        )
                        .project((0..(oa + ra)).collect());

                        // Determine the types of nulls to use as filler.
                        let left_fill = lt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, prepended with typed nulls.
                        result = right_present
                            .negate()
                            .union(get_right.clone())
                            .map(left_fill)
                            // Permute left fill before right values.
                            .project(
                                itertools::chain!(
                                    0..oa,                 // Keep `outer` in place.
                                    oa + ra..oa + la + ra, // Then the `la` left-side nulls.
                                    oa..oa + ra            // Then the `ra` right-side values.
                                )
                                .collect(),
                            )
                            .union(result)
                    }

                    Ok::<_, PlanError>(result)
                })
            })
        })
    })?;
    Ok(Some(result))
}
2249
2250/// A struct that represents the predicates in the `on` clause in a form
2251/// suitable for efficient planning outer joins with equijoin predicates.
2252struct OnPredicates {
2253    /// A store for classified `ON` predicates.
2254    ///
2255    /// Predicates that reference a single side are adjusted to assume an
2256    /// `outer × <side>` schema.
2257    predicates: Vec<OnPredicate>,
2258    /// Number of outer context columns.
2259    oa: usize,
2260}
2261
2262impl OnPredicates {
2263    const I_OUT: usize = 0; // outer context input position
2264    const I_LHS: usize = 1; // lhs input position
2265    const I_RHS: usize = 2; // rhs input position
2266    const I_SUB: usize = 3; // on subqueries input position
2267
2268    /// Classify the predicates in the `on` clause of an outer join.
2269    ///
2270    /// The other parameters are arities of the input parts:
2271    ///
2272    /// - `oa` is the arity of the `outer` context.
2273    /// - `la` is the arity of the `left` input.
2274    /// - `ra` is the arity of the `right` input.
2275    /// - `sa` is the arity of the `on` subqueries.
2276    ///
2277    /// The constructor assumes that:
2278    ///
2279    /// 1. The `on` parameter will be applied on a result that has the following
2280    ///    schema `outer × left × right × on_subqueries`.
2281    /// 2. The `on` parameter is already adjusted to assume that schema.
2282    /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2283    ///    MirScalarExpr` with `canonicalize_predicates`.
2284    fn new(
2285        oa: usize,
2286        la: usize,
2287        ra: usize,
2288        sa: usize,
2289        on: Vec<MirScalarExpr>,
2290        _context: &Context,
2291    ) -> Self {
2292        use mz_expr::BinaryFunc::Eq;
2293
2294        // Re-bind those locally for more compact pattern matching.
2295        const I_LHS: usize = OnPredicates::I_LHS;
2296        const I_RHS: usize = OnPredicates::I_RHS;
2297
2298        // Self parameters.
2299        let mut predicates = Vec::with_capacity(on.len());
2300
2301        // Helpers for populating `predicates`.
2302        let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2303        let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2304        let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2305            inner_join_mapper
2306                .lookup_inputs(expr)
2307                .filter(|&i| i != Self::I_OUT)
2308                .collect()
2309        };
2310        let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2311            inner_join_mapper
2312                .lookup_inputs(expr)
2313                .any(|i| i == Self::I_SUB)
2314        };
2315
2316        // Iterate over `on` elements and populate `predicates`.
2317        for mut predicate in on {
2318            if predicate.might_error() {
2319                tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2320                // Treat predicates that can produce a literal error as Theta.
2321                predicates.push(OnPredicate::Theta(predicate));
2322            } else if has_subquery_refs(&predicate) {
2323                tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2324                // Treat predicates referencing an `on` subquery as Theta.
2325                predicates.push(OnPredicate::Theta(predicate));
2326            } else if let MirScalarExpr::CallBinary {
2327                func: Eq(_),
2328                expr1,
2329                expr2,
2330            } = &mut predicate
2331            {
2332                // Obtain the non-outer inputs referenced by each side.
2333                let inputs1 = lookup_inputs(expr1);
2334                let inputs2 = lookup_inputs(expr2);
2335
2336                match (&inputs1[..], &inputs2[..]) {
2337                    // Neither side references an input. This could be a
2338                    // constant expression or an expression that depends only on
2339                    // the outer context.
2340                    ([], []) => {
2341                        predicates.push(OnPredicate::Const(predicate));
2342                    }
2343                    // Both sides reference different inputs.
2344                    ([I_LHS], [I_RHS]) => {
2345                        let lhs = expr1.take();
2346                        let mut rhs = expr2.take();
2347                        rhs.permute(&rhs_permutation);
2348                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2349                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2350                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2351                    }
2352                    // Both sides reference different inputs (swapped).
2353                    ([I_RHS], [I_LHS]) => {
2354                        let lhs = expr2.take();
2355                        let mut rhs = expr1.take();
2356                        rhs.permute(&rhs_permutation);
2357                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2358                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2359                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2360                    }
2361                    // Both sides reference the left input or no input.
2362                    ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2363                        predicates.push(OnPredicate::Lhs(predicate));
2364                    }
2365                    // Both sides reference the right input or no input.
2366                    ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2367                        predicate.permute(&rhs_permutation);
2368                        predicates.push(OnPredicate::Rhs(predicate));
2369                    }
2370                    // At least one side references more than one input.
2371                    _ => {
2372                        tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2373                        predicates.push(OnPredicate::Theta(predicate));
2374                    }
2375                }
2376            } else {
2377                // Obtain the non-outer inputs referenced by this predicate.
2378                let inputs = lookup_inputs(&predicate);
2379
2380                match &inputs[..] {
2381                    // The predicate references no inputs. This could be a
2382                    // constant expression or an expression that depends only on
2383                    // the outer context.
2384                    [] => {
2385                        predicates.push(OnPredicate::Const(predicate));
2386                    }
2387                    // The predicate references only the left input.
2388                    [I_LHS] => {
2389                        predicates.push(OnPredicate::Lhs(predicate));
2390                    }
2391                    // The predicate references only the right input.
2392                    [I_RHS] => {
2393                        predicate.permute(&rhs_permutation);
2394                        predicates.push(OnPredicate::Rhs(predicate));
2395                    }
2396                    // The predicate references both inputs.
2397                    _ => {
2398                        tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2399                        predicates.push(OnPredicate::Theta(predicate));
2400                    }
2401                }
2402            }
2403        }
2404
2405        Self { predicates, oa }
2406    }
2407
2408    /// Check if the predicates can be lowered with an equijoin-based strategy.
2409    fn is_equijoin(&self, context: &Context) -> bool {
2410        // Count each `OnPredicate` variant in `self.predicates`.
2411        let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2412            self.predicates.iter().fold(
2413                (0, 0, 0, 0, 0, 0),
2414                |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2415                    (
2416                        const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2417                        lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2418                        rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2419                        eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2420                        eq_cols + usize::from(matches!(p, OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column())),
2421                        theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2422                    )
2423                },
2424            );
2425
2426        let is_equijion = if context.config.enable_new_outer_join_lowering {
2427            // New classifier.
2428            eq_cnt > 0 && theta_cnt == 0
2429        } else {
2430            // Old classifier.
2431            eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2432        };
2433
2434        // Log an entry only if this is an equijoin according to the new classifier.
2435        if eq_cnt > 0 && theta_cnt == 0 {
2436            tracing::debug!(
2437                const_cnt,
2438                lhs_cnt,
2439                rhs_cnt,
2440                eq_cnt,
2441                eq_cols,
2442                theta_cnt,
2443                "OnPredicates::is_equijoin"
2444            );
2445        }
2446
2447        is_equijion
2448    }
2449
2450    /// Return an [`MirRelationExpr`] list that represents the keys for the
2451    /// equijoin. The list will contain the outer columns as a prefix.
2452    fn join_keys(&self) -> JoinKeys {
2453        // We could return either the `lhs` or the `rhs` of the keys used to
2454        // form the inner join as they are equated by the join condition.
2455        let join_keys = self.eq_lhs().collect::<Vec<_>>();
2456
2457        if join_keys.iter().all(|k| k.is_column()) {
2458            tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2459            JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2460        } else {
2461            tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2462            JoinKeys::Scalars(join_keys)
2463        }
2464    }
2465
2466    /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2467    /// conditions in the predicates list.
2468    ///
2469    /// The iterator will start with column references to the outer columns as a
2470    /// prefix.
2471    fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2472        itertools::chain(
2473            (0..self.oa).map(MirScalarExpr::column),
2474            self.predicates.iter().filter_map(|e| match e {
2475                OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2476                _ => None,
2477            }),
2478        )
2479    }
2480
2481    /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2482    /// conditions in the predicates list.
2483    ///
2484    /// The iterator will start with column references to the outer columns as a
2485    /// prefix.
2486    fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2487        itertools::chain(
2488            (0..self.oa).map(MirScalarExpr::column),
2489            self.predicates.iter().filter_map(|e| match e {
2490                OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2491                _ => None,
2492            }),
2493        )
2494    }
2495
2496    /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2497    /// [`OnPredicate::Const`] conditions in the predicates list.
2498    fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2499        self.predicates.iter().filter_map(|p| match p {
2500            // We treat Const predicates local to both inputs.
2501            OnPredicate::Const(p) => Some(p.clone()),
2502            OnPredicate::Lhs(p) => Some(p.clone()),
2503            OnPredicate::LhsConsequence(p) => Some(p.clone()),
2504            _ => None,
2505        })
2506    }
2507
2508    /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2509    /// [`OnPredicate::Const`] conditions in the predicates list.
2510    fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2511        self.predicates.iter().filter_map(|p| match p {
2512            // We treat Const predicates local to both inputs.
2513            OnPredicate::Const(p) => Some(p.clone()),
2514            OnPredicate::Rhs(p) => Some(p.clone()),
2515            OnPredicate::RhsConsequence(p) => Some(p.clone()),
2516            _ => None,
2517        })
2518    }
2519}
2520
2521enum OnPredicate {
2522    // A predicate that is either constant or references only outer columns.
2523    Const(MirScalarExpr),
2524    // A local predicate on the left-hand side of the join, i.e., it references only the left input
2525    // and possibly outer columns.
2526    //
2527    // This is one of the original predicates from the ON clause.
2528    //
2529    // One _must_ apply this predicate.
2530    Lhs(MirScalarExpr),
2531    // A local predicate on the left-hand side of the join, i.e., it references only the left input
2532    // and possibly outer columns.
2533    //
2534    // This is not one of the original predicates from the ON clause, but is just a consequence
2535    // of an original predicate in the ON clause, where the original predicate references both
2536    // inputs, but the consequence references only the left input.
2537    //
2538    // For example, the original predicate `input1.x = input2.a` has the consequence
2539    // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
2540    // prevents null skew, and also makes more CSE opportunities available when the left input's key
2541    // doesn't have a NOT NULL constraint, saving us an arrangement.
2542    //
2543    // Applying the predicate is optional, because the original predicate will be applied anyway.
2544    LhsConsequence(MirScalarExpr),
2545    // A local predicate on the right-hand side of the join.
2546    //
2547    // This is one of the original predicates from the ON clause.
2548    //
2549    // One _must_ apply this predicate.
2550    Rhs(MirScalarExpr),
2551    // A consequence of an original ON predicate, see above.
2552    RhsConsequence(MirScalarExpr),
2553    // An equality predicate between the two sides.
2554    Eq(MirScalarExpr, MirScalarExpr),
2555    // a non-equality predicate between the two sides.
2556    #[allow(dead_code)]
2557    Theta(MirScalarExpr),
2558}
2559
2560/// A set of join keys referencing an input.
2561///
2562/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
2563/// avoid changes (and thereby possible regressions) in plans that have equijoin
2564/// predicates consisting only of column refs.
2565///
2566/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
2567/// able to get rid of this code, but as it stands `Map` simplification seems
2568/// more cumbersome than `Project` simplification, so do this just to be sure.
2569enum JoinKeys {
2570    // A predicate that is either constant or references only outer columns.
2571    Outputs(Vec<usize>),
2572    // A local predicate on the left-hand side of the join.
2573    Scalars(Vec<MirScalarExpr>),
2574}
2575
2576impl JoinKeys {
2577    fn len(&self) -> usize {
2578        match self {
2579            JoinKeys::Outputs(outputs) => outputs.len(),
2580            JoinKeys::Scalars(scalars) => scalars.len(),
2581        }
2582    }
2583}
2584
2585/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
2586/// code.
2587trait LoweringExt {
2588    /// See [`MirRelationExpr::restrict`].
2589    fn restrict(self, join_keys: JoinKeys) -> Self;
2590}
2591
2592impl LoweringExt for MirRelationExpr {
2593    /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2594    fn restrict(self, join_keys: JoinKeys) -> Self {
2595        let num_keys = join_keys.len();
2596        match join_keys {
2597            JoinKeys::Outputs(outputs) => self.project(outputs),
2598            JoinKeys::Scalars(scalars) => {
2599                let input_arity = self.arity();
2600                let outputs = (input_arity..input_arity + num_keys).collect();
2601                self.map(scalars).project(outputs)
2602            }
2603        }
2604    }
2605}