Skip to main content

mz_sql/plan/
lowering.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::func::variadic;
44use mz_expr::visit::Visit;
45use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
46use mz_ore::collections::CollectionExt;
47use mz_ore::stack::maybe_grow;
48use mz_repr::*;
49
50use crate::optimizer_metrics::OptimizerMetrics;
51use crate::plan::hir::{
52    AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
53};
54use crate::plan::{PlanError, transform_hir};
55use crate::session::vars::SystemVars;
56
57mod variadic_left;
58
59/// Maps a leveled column reference to a specific column.
60///
61/// Leveled column references are nested, so that larger levels are
62/// found early in a record and level zero is found at the end.
63///
64/// The column map only stores references for levels greater than zero,
65/// and column references at level zero simply start at the first column
66/// after all prior references.
67#[derive(Debug, Clone)]
68struct ColumnMap {
69    inner: BTreeMap<ColumnRef, usize>,
70}
71
72impl ColumnMap {
73    fn empty() -> ColumnMap {
74        Self::new(BTreeMap::new())
75    }
76
77    fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
78        ColumnMap { inner }
79    }
80
81    fn get(&self, col_ref: &ColumnRef) -> usize {
82        if col_ref.level == 0 {
83            self.inner.len() + col_ref.column
84        } else {
85            self.inner[col_ref]
86        }
87    }
88
89    fn len(&self) -> usize {
90        self.inner.len()
91    }
92
93    /// Updates references in the `ColumnMap` for use in a nested scope. The
94    /// provided `arity` must specify the arity of the current scope.
95    fn enter_scope(&self, arity: usize) -> ColumnMap {
96        // From the perspective of the nested scope, all existing column
97        // references will be one level greater.
98        let existing = self
99            .inner
100            .clone()
101            .into_iter()
102            .update(|(col, _i)| col.level += 1);
103
104        // All columns in the current scope become explicit entries in the
105        // immediate parent scope.
106        let new = (0..arity).map(|i| {
107            (
108                ColumnRef {
109                    level: 1,
110                    column: i,
111                },
112                self.len() + i,
113            )
114        });
115
116        ColumnMap::new(existing.chain(new).collect())
117    }
118}
119
120/// Map with the CTEs currently in scope.
121type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
122
123/// Information about needed when finding a reference to a CTE in scope.
124#[derive(Clone)]
125struct CteDesc {
126    /// The new ID assigned to the lowered version of the CTE, which may not match
127    /// the ID of the input CTE.
128    new_id: mz_expr::LocalId,
129    /// The relation type of the CTE including the columns from the outer
130    /// context at the beginning.
131    relation_type: ReprRelationType,
132    /// The outer relation the CTE was applied to.
133    outer_relation: MirRelationExpr,
134}
135
/// Feature flags that affect how HIR is lowered to MIR.
///
/// Every flag defaults to `false`. The `From<&SystemVars>` impl populates the flags from the
/// corresponding system variables.
// `Default` is derived: the previous hand-written impl set every field to `false`, which is
// exactly what the derive produces for `bool` fields (clippy: `derivable_impls`).
#[derive(Debug, Clone, Copy, Default)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
    /// Mirrors the `enable_cast_elimination` system variable.
    pub enable_cast_elimination: bool,
    /// Mirrors the `enable_simplify_quantified_comparisons` system variable; passed to
    /// `transform_hir::try_simplify_quantified_comparisons` during lowering.
    pub enable_simplify_quantified_comparisons: bool,
}
156
157impl From<&SystemVars> for Config {
158    fn from(vars: &SystemVars) -> Self {
159        Self {
160            enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
161            enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
162            enable_cast_elimination: vars.enable_cast_elimination(),
163            enable_simplify_quantified_comparisons: vars.enable_simplify_quantified_comparisons(),
164        }
165    }
166}
167
168/// Context passed to the lowering. This is wired to most parts of the lowering.
169pub(crate) struct Context<'a> {
170    /// Feature flags affecting the behavior of lowering.
171    pub config: &'a Config,
172    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's None, we
173    /// simply don't write metrics.
174    pub metrics: Option<&'a OptimizerMetrics>,
175}
176
177impl HirRelationExpr {
178    /// Rewrite `self` into a `MirRelationExpr`.
179    /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
180    #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
181    pub fn lower<C: Into<Config>>(
182        self,
183        config: C,
184        metrics: Option<&OptimizerMetrics>,
185    ) -> Result<MirRelationExpr, PlanError> {
186        let context = Context {
187            config: &config.into(),
188            metrics,
189        };
190        let result = match self {
191            // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
192            // to ensure that the downstream optimizer can easily bypass most
193            // irrelevant optimizations (e.g. reduce folding) for this expression
194            // without having to re-learn the fact that it is just a constant,
195            // as it would if the constant were wrapped in a Let-Get pair.
196            HirRelationExpr::Constant { rows, typ } => {
197                let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
198                MirRelationExpr::Constant {
199                    rows: Ok(rows),
200                    typ: ReprRelationType::from(&typ),
201                }
202            }
203            mut other => {
204                let mut id_gen = mz_ore::id_gen::IdGen::default();
205                transform_hir::split_subquery_predicates(&mut other)?;
206                transform_hir::try_simplify_quantified_comparisons(
207                    &mut other,
208                    context.config.enable_simplify_quantified_comparisons,
209                )?;
210                transform_hir::fuse_window_functions(&mut other, &context)?;
211                MirRelationExpr::constant(vec![vec![]], ReprRelationType::new(vec![])).let_in(
212                    &mut id_gen,
213                    |id_gen, get_outer| {
214                        other.applied_to(
215                            id_gen,
216                            get_outer,
217                            &ColumnMap::empty(),
218                            &mut CteMap::new(),
219                            &context,
220                        )
221                    },
222                )?
223            }
224        };
225
226        mz_repr::explain::trace_plan(&result);
227
228        Ok(result)
229    }
230
231    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
232    ///
233    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
234    /// When `self` references columns of `get_outer`, much more work needs to occur.
235    ///
236    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
237    /// perhaps not all of them. It should be used as the basis of resolving column references,
238    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
239    /// will start, rather than any function of `col_map`.
240    ///
241    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
242    /// assignment of values to outer rows.
243    fn applied_to(
244        self,
245        id_gen: &mut mz_ore::id_gen::IdGen,
246        get_outer: MirRelationExpr,
247        col_map: &ColumnMap,
248        cte_map: &mut CteMap,
249        context: &Context,
250    ) -> Result<MirRelationExpr, PlanError> {
251        maybe_grow(|| {
252            use MirRelationExpr as SR;
253
254            use HirRelationExpr::*;
255
256            if let MirRelationExpr::Get { .. } = &get_outer {
257            } else {
258                panic!(
259                    "get_outer: expected a MirRelationExpr::Get, found\n{}",
260                    get_outer.pretty(),
261                );
262            }
263            assert_eq!(col_map.len(), get_outer.arity());
264            Ok(match self {
265                Constant { rows, typ } => {
266                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
267                    get_outer.product(SR::Constant {
268                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
269                        typ: ReprRelationType::from(&typ),
270                    })
271                }
272                Get { id, typ } => match id {
273                    mz_expr::Id::Local(local_id) => {
274                        let cte_desc = cte_map.get(&local_id).unwrap();
275                        let get_cte = SR::Get {
276                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
277                            typ: cte_desc.relation_type.clone(),
278                            access_strategy: AccessStrategy::UnknownOrLocal,
279                        };
280                        if get_outer == cte_desc.outer_relation {
281                            // If the CTE was applied to the same exact relation, we can safely
282                            // return a `Get` relation.
283                            get_cte
284                        } else {
285                            // Otherwise, the new outer relation may contain more columns from some
286                            // intermediate scope placed between the definition of the CTE and this
287                            // reference of the CTE and/or more operations applied on top of the
288                            // outer relation.
289                            //
290                            // An example of the latter is the following query:
291                            //
292                            // SELECT *
293                            // FROM x,
294                            //      LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
295                            //              SELECT (SELECT m FROM a) FROM y) b;
296                            //
297                            // When the CTE is lowered, the outer relation is `Get x`. But then,
298                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
299                            // which has the same cardinality as `Get x`.
300                            //
301                            // In any case, `get_outer` is guaranteed to contain the columns of the
302                            // outer relation the CTE was applied to at its prefix. Since, we must
303                            // return a relation containing `get_outer`'s column at the beginning,
304                            // we must build a join between `get_outer` and `get_cte` on their common
305                            // columns.
306                            let oa = get_outer.arity();
307                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
308                            let equivalences = (0..cte_outer_columns)
309                                .map(|pos| {
310                                    vec![
311                                        MirScalarExpr::column(pos),
312                                        MirScalarExpr::column(pos + oa),
313                                    ]
314                                })
315                                .collect();
316
317                            // Project out the second copy of the common between `get_outer` and
318                            // `cte_desc.outer_relation`.
319                            let projection = (0..oa)
320                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
321                                .collect_vec();
322                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
323                                .project(projection)
324                        }
325                    }
326                    mz_expr::Id::Global(_) => {
327                        // Get statements are only to external sources, and are not correlated with `get_outer`.
328                        get_outer.product(SR::Get {
329                            id,
330                            typ: ReprRelationType::from(&typ),
331                            access_strategy: AccessStrategy::UnknownOrLocal,
332                        })
333                    }
334                },
335                Let {
336                    name: _,
337                    id,
338                    value,
339                    body,
340                } => {
341                    let value =
342                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
343                    value.let_in(id_gen, |id_gen, get_value| {
344                        let (new_id, typ) = if let MirRelationExpr::Get {
345                            id: mz_expr::Id::Local(id),
346                            typ,
347                            ..
348                        } = get_value
349                        {
350                            (id, typ)
351                        } else {
352                            panic!(
353                                "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
354                                get_value.pretty(),
355                            );
356                        };
357                        // Add the information about the CTE to the map and remove it when
358                        // it goes out of scope.
359                        let old_value = cte_map.insert(
360                            id.clone(),
361                            CteDesc {
362                                new_id,
363                                relation_type: typ,
364                                outer_relation: get_outer.clone(),
365                            },
366                        );
367                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
368                        if let Some(old_value) = old_value {
369                            cte_map.insert(id, old_value);
370                        } else {
371                            cte_map.remove(&id);
372                        }
373                        body
374                    })?
375                }
376                LetRec {
377                    limit,
378                    bindings,
379                    body,
380                } => {
381                    let num_bindings = bindings.len();
382
383                    // We use the outer type with the HIR types to form MIR CTE types.
384                    let outer_column_types = get_outer.typ().column_types;
385
386                    // Rename and introduce all bindings.
387                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
388                    let mut mir_ids = Vec::with_capacity(num_bindings);
389                    for (_name, id, _value, typ) in bindings.iter() {
390                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
391                        mir_ids.push(mir_id);
392                        let shadowed = cte_map.insert(
393                            id.clone(),
394                            CteDesc {
395                                new_id: mir_id,
396                                relation_type: ReprRelationType::new(
397                                    outer_column_types
398                                        .iter()
399                                        .cloned()
400                                        .chain(typ.column_types.iter().map(ReprColumnType::from))
401                                        .collect::<Vec<_>>(),
402                                ),
403                                outer_relation: get_outer.clone(),
404                            },
405                        );
406                        shadowed_bindings.push((*id, shadowed));
407                    }
408
409                    let mut mir_values = Vec::with_capacity(num_bindings);
410                    for (_name, _id, value, _typ) in bindings.into_iter() {
411                        mir_values.push(value.applied_to(
412                            id_gen,
413                            get_outer.clone(),
414                            col_map,
415                            cte_map,
416                            context,
417                        )?);
418                    }
419
420                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
421
422                    // Remove our bindings and reinstate any shadowed bindings.
423                    for (id, shadowed) in shadowed_bindings {
424                        if let Some(shadowed) = shadowed {
425                            cte_map.insert(id, shadowed);
426                        } else {
427                            cte_map.remove(&id);
428                        }
429                    }
430
431                    MirRelationExpr::LetRec {
432                        ids: mir_ids,
433                        values: mir_values,
434                        // Copy the limit to each binding.
435                        limits: repeat(limit).take(num_bindings).collect(),
436                        body: Box::new(mir_body),
437                    }
438                }
439                Project { input, outputs } => {
440                    // Projections should be applied to the decorrelated `inner`, and to its columns,
441                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
442                    let input =
443                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
444                    let outputs = (0..get_outer.arity())
445                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
446                        .collect::<Vec<_>>();
447                    input.project(outputs)
448                }
449                Map { input, mut scalars } => {
450                    // Scalar expressions may contain correlated subqueries. We must be cautious!
451
452                    // We lower scalars in chunks, and must keep track of the
453                    // arity of the HIR fragments lowered so far.
454                    let mut lowered_arity = input.arity();
455
456                    let mut input =
457                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
458
459                    // Lower subqueries in maximally sized batches, such as no subquery in the current
460                    // batch depends on columns from the same batch.
461                    // Note that subqueries in this projection may reference columns added by this
462                    // Map operator, so we need to ensure these columns exist before lowering the
463                    // subquery.
464                    while !scalars.is_empty() {
465                        let end_idx = scalars
466                            .iter_mut()
467                            .position(|s| {
468                                let mut requires_nonexistent_column = false;
469                                #[allow(deprecated)]
470                                s.visit_columns(0, &mut |depth, col| {
471                                    if col.level == depth {
472                                        requires_nonexistent_column |= col.column >= lowered_arity
473                                    }
474                                });
475                                requires_nonexistent_column
476                            })
477                            .unwrap_or(scalars.len());
478                        assert!(
479                            end_idx > 0,
480                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
481                            lowered_arity,
482                            scalars
483                        );
484
485                        lowered_arity = lowered_arity + end_idx;
486                        let scalars = scalars.drain(0..end_idx).collect_vec();
487
488                        let old_arity = input.arity();
489                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
490                            &scalars, id_gen, col_map, cte_map, input, context,
491                        )?;
492                        input = with_subqueries;
493
494                        // We will proceed sequentially through the scalar expressions, for each transforming
495                        // the decorrelated `input` into a relation with potentially more columns capable of
496                        // addressing the needs of the scalar expression.
497                        // Having done so, we add the scalar value of interest and trim off any other newly
498                        // added columns.
499                        //
500                        // The sequential traversal is present as expressions are allowed to depend on the
501                        // values of prior expressions.
502                        let mut scalar_columns = Vec::new();
503                        for scalar in scalars {
504                            let scalar = scalar.applied_to(
505                                id_gen,
506                                col_map,
507                                cte_map,
508                                &mut input,
509                                &Some(&subquery_map),
510                                context,
511                            )?;
512                            input = input.map_one(scalar);
513                            scalar_columns.push(input.arity() - 1);
514                        }
515
516                        // Discard any new columns added by the lowering of the scalar expressions
517                        input = input.project((0..old_arity).chain(scalar_columns).collect());
518                    }
519
520                    input
521                }
522                CallTable { func, exprs } => {
523                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
524                    // allowed to refer to the results of previous expressions, and we have a simpler
525                    // implementation that appends all relevant columns first, then applies the flatmap
526                    // operator to the result, then strips off any columns introduce by subqueries.
527
528                    let mut input = get_outer;
529                    let old_arity = input.arity();
530
531                    let exprs = exprs
532                        .into_iter()
533                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
534                        .collect::<Result<Vec<_>, _>>()?;
535
536                    let new_arity = input.arity();
537                    let output_arity = func.output_arity();
538                    input = input.flat_map(func, exprs);
539                    if old_arity != new_arity {
540                        // this means we added some columns to handle subqueries, and now we need to get rid of them
541                        input = input.project(
542                            (0..old_arity)
543                                .chain(new_arity..new_arity + output_arity)
544                                .collect(),
545                        );
546                    }
547                    input
548                }
549                Filter { input, predicates } => {
550                    // Filter expressions may contain correlated subqueries.
551                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
552                    // then filter the results, then strip off any columns that were added for this purpose.
553                    let mut input =
554                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
555                    for predicate in predicates {
556                        let old_arity = input.arity();
557                        let predicate = predicate
558                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
559                        let new_arity = input.arity();
560                        input = input.filter(vec![predicate]);
561                        if old_arity != new_arity {
562                            // this means we added some columns to handle subqueries, and now we need to get rid of them
563                            input = input.project((0..old_arity).collect());
564                        }
565                    }
566                    input
567                }
568                Join {
569                    left,
570                    right,
571                    on,
572                    kind,
573                } if right.is_correlated() => {
574                    // A correlated join is a join in which the right expression has
575                    // access to the columns in the left expression. It turns out
576                    // this is *exactly* our branch operator, plus some additional
577                    // null handling in the case of left joins. (Right and full
578                    // lateral joins are not permitted.)
579                    //
580                    // As with normal joins, the `on` predicate may be correlated,
581                    // and we treat it as a filter that follows the branch.
582
583                    assert!(kind.can_be_correlated());
584
585                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
586                    left.let_in(id_gen, |id_gen, get_left| {
587                        let apply_requires_distinct_outer = false;
588                        let mut join = branch(
589                            id_gen,
590                            get_left.clone(),
591                            col_map,
592                            cte_map,
593                            *right,
594                            apply_requires_distinct_outer,
595                            context,
596                            |id_gen, right, get_left, col_map, cte_map, context| {
597                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
598                            },
599                        )?;
600
601                        // Plan the `on` predicate.
602                        let old_arity = join.arity();
603                        let on =
604                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
605                        join = join.filter(vec![on]);
606                        let new_arity = join.arity();
607                        if old_arity != new_arity {
608                            // This means we added some columns to handle
609                            // subqueries, and now we need to get rid of them.
610                            join = join.project((0..old_arity).collect());
611                        }
612
613                        // If a left join, reintroduce any rows from the left that
614                        // are missing, with nulls filled in for the right columns.
615                        if let JoinKind::LeftOuter { .. } = kind {
616                            let default = join
617                                .typ()
618                                .column_types
619                                .into_iter()
620                                .skip(get_left.arity())
621                                .map(|typ| (Datum::Null, typ.scalar_type))
622                                .collect();
623                            get_left.lookup(id_gen, join, default)
624                        } else {
625                            Ok::<_, PlanError>(join)
626                        }
627                    })?
628                }
629                Join {
630                    left,
631                    right,
632                    on,
633                    kind,
634                } => {
635                    if context.config.enable_variadic_left_join_lowering {
636                        // Attempt to extract a stack of left joins.
637                        if let JoinKind::LeftOuter = kind {
638                            let mut rights = vec![(&*right, &on)];
639                            let mut left_test = &left;
640                            while let Join {
641                                left,
642                                right,
643                                on,
644                                kind: JoinKind::LeftOuter,
645                            } = &**left_test
646                            {
647                                rights.push((&**right, on));
648                                left_test = left;
649                            }
650                            if rights.len() > 1 {
651                                // Defensively clone `cte_map` as it may be mutated.
652                                let cte_map_clone = cte_map.clone();
653                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
654                                    left_test,
655                                    rights,
656                                    id_gen,
657                                    get_outer.clone(),
658                                    col_map,
659                                    cte_map,
660                                    context,
661                                ) {
662                                    return Ok(magic);
663                                } else {
664                                    cte_map.clone_from(&cte_map_clone);
665                                }
666                            }
667                        }
668                    }
669
670                    // Both join expressions should be decorrelated, and then joined by their
671                    // leading columns to form only those pairs corresponding to the same row
672                    // of `get_outer`.
673                    //
674                    // The `on` predicate may contain correlated subqueries, and we treat it
675                    // as though it was a filter, with the caveat that we also translate outer
676                    // joins in this step. The post-filtration results need to be considered
677                    // against the records present in the left and right (decorrelated) inputs,
678                    // depending on the type of join.
679                    let oa = get_outer.arity();
680                    let left =
681                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
682                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
683                    let la = lt.len();
684                    left.let_in(id_gen, |id_gen, get_left| {
685                        let right_col_map = col_map.enter_scope(0);
686                        let right = right.applied_to(
687                            id_gen,
688                            get_outer.clone(),
689                            &right_col_map,
690                            cte_map,
691                            context,
692                        )?;
693                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
694                        let ra = rt.len();
695                        right.let_in(id_gen, |id_gen, get_right| {
696                            let mut product = SR::join(
697                                vec![get_left.clone(), get_right.clone()],
698                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
699                            )
700                            // Project away the repeated copy of get_outer's columns.
701                            .project(
702                                (0..(oa + la))
703                                    .chain((oa + la + oa)..(oa + la + oa + ra))
704                                    .collect(),
705                            );
706
707                            // Decorrelate and lower the `on` clause.
708                            let on = on.applied_to(
709                                id_gen,
710                                col_map,
711                                cte_map,
712                                &mut product,
713                                &None,
714                                context,
715                            )?;
716                            // Collect the types of all subqueries appearing in
717                            // the `on` clause. The subquery results were
718                            // appended to `product` in the `on.applied_to(...)`
719                            // call above.
720                            let on_subquery_types = product
721                                .typ()
722                                .column_types
723                                .drain(oa + la + ra..)
724                                .collect_vec();
725                            // Remember if `on` had any subqueries.
726                            let on_has_subqueries = !on_subquery_types.is_empty();
727
728                            // Attempt an efficient equijoin implementation, in which outer joins are
729                            // more efficiently rendered than in general. This can return `None` if
730                            // such a plan is not possible, for example if `on` does not describe an
731                            // equijoin between columns of `left` and `right`.
732                            if kind != JoinKind::Inner {
733                                if let Some(joined) = attempt_outer_equijoin(
734                                    get_left.clone(),
735                                    get_right.clone(),
736                                    on.clone(),
737                                    on_subquery_types,
738                                    kind.clone(),
739                                    oa,
740                                    id_gen,
741                                    context,
742                                )? {
743                                    if let Some(metrics) = context.metrics {
744                                        metrics.inc_outer_join_lowering("equi");
745                                    }
746                                    return Ok(joined);
747                                }
748                            }
749
750                            // Otherwise, perform a more general join.
751                            if let Some(metrics) = context.metrics {
752                                metrics.inc_outer_join_lowering("general");
753                            }
754                            let mut join = product.filter(vec![on]);
755                            if on_has_subqueries {
756                                // This means that `on.applied_to(...)` appended
757                                // some columns to handle subqueries, and now we
758                                // need to get rid of them.
759                                join = join.project((0..oa + la + ra).collect());
760                            }
761                            join.let_in(id_gen, |id_gen, get_join| {
762                                let mut result = get_join.clone();
763                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
764                                    kind
765                                {
766                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
767                                        id_gen,
768                                        get_join.clone(),
769                                        rt.into_iter()
770                                            .map(|typ| (Datum::Null, typ.scalar_type))
771                                            .collect(),
772                                    )?;
773                                    result = result.union(left_outer);
774                                }
775                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
776                                    let right_outer = get_right
777                                        .clone()
778                                        .anti_lookup::<PlanError>(
779                                            id_gen,
780                                            get_join
781                                                // need to swap left and right to make the anti_lookup work
782                                                .project(
783                                                    (0..oa)
784                                                        .chain((oa + la)..(oa + la + ra))
785                                                        .chain((oa)..(oa + la))
786                                                        .collect(),
787                                                ),
788                                            lt.into_iter()
789                                                .map(|typ| (Datum::Null, typ.scalar_type))
790                                                .collect(),
791                                        )?
792                                        // swap left and right back again
793                                        .project(
794                                            (0..oa)
795                                                .chain((oa + ra)..(oa + ra + la))
796                                                .chain((oa)..(oa + ra))
797                                                .collect(),
798                                        );
799                                    result = result.union(right_outer);
800                                }
801                                Ok::<MirRelationExpr, PlanError>(result)
802                            })
803                        })
804                    })?
805                }
806                Union { base, inputs } => {
807                    // Union is uncomplicated.
808                    SR::Union {
809                        base: Box::new(base.applied_to(
810                            id_gen,
811                            get_outer.clone(),
812                            col_map,
813                            cte_map,
814                            context,
815                        )?),
816                        inputs: inputs
817                            .into_iter()
818                            .map(|input| {
819                                input.applied_to(
820                                    id_gen,
821                                    get_outer.clone(),
822                                    col_map,
823                                    cte_map,
824                                    context,
825                                )
826                            })
827                            .collect::<Result<Vec<_>, _>>()?,
828                    }
829                }
830                Reduce {
831                    input,
832                    group_key,
833                    aggregates,
834                    expected_group_size,
835                } => {
836                    // Reduce may contain expressions with correlated subqueries.
837                    // In addition, here an empty reduction key signifies that we need to supply default values
838                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
839                    let mut input =
840                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
841                    let applied_group_key = (0..get_outer.arity())
842                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
843                        .collect();
844                    let applied_aggregates = aggregates
845                        .into_iter()
846                        .map(|aggregate| {
847                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
848                        })
849                        .collect::<Result<Vec<_>, _>>()?;
850                    let input_type = input.typ();
851                    let default = applied_aggregates
852                        .iter()
853                        .map(|agg| {
854                            (
855                                agg.func.default(),
856                                agg.typ(&input_type.column_types).scalar_type,
857                            )
858                        })
859                        .collect();
860                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
861                    let mut reduced =
862                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);
863
864                    // Introduce default values in the case the group key is empty.
865                    if group_key.is_empty() {
866                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
867                    }
868                    reduced
869                }
870                Distinct { input } => {
871                    // Distinct is uncomplicated.
872                    input
873                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
874                        .distinct()
875                }
876                TopK {
877                    input,
878                    group_key,
879                    order_key,
880                    limit,
881                    offset,
882                    expected_group_size,
883                } => {
884                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
885                    let mut input =
886                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
887                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
888                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
889                        .collect();
890                    let applied_order_key = order_key
891                        .iter()
892                        .map(|column_order| ColumnOrder {
893                            column: column_order.column + get_outer.arity(),
894                            desc: column_order.desc,
895                            nulls_last: column_order.nulls_last,
896                        })
897                        .collect();
898
899                    let old_arity = input.arity();
900
901                    // Lower `limit`, which may introduce new columns if it is a correlated subquery.
902                    let mut limit_mir = None;
903                    if let Some(limit) = limit {
904                        limit_mir = Some(
905                            limit
906                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
907                        );
908                    }
909
910                    let new_arity = input.arity();
911                    // Extend the key to contain any new columns.
912                    applied_group_key.extend(old_arity..new_arity);
913
914                    let offset = offset
915                        .try_into_literal_int64()
916                        .expect("Should be a Literal by this time")
917                        .try_into()
918                        .expect("Should have checked non-negativity of OFFSET clause already");
919                    let mut result = input.top_k(
920                        applied_group_key,
921                        applied_order_key,
922                        limit_mir,
923                        offset,
924                        expected_group_size,
925                    );
926
927                    // If new columns were added for `limit` we must remove them.
928                    if old_arity != new_arity {
929                        result = result.project((0..old_arity).collect());
930                    }
931
932                    result
933                }
934                Negate { input } => {
935                    // Negate is uncomplicated.
936                    input
937                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
938                        .negate()
939                }
940                Threshold { input } => {
941                    // Threshold is uncomplicated.
942                    input
943                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
944                        .threshold()
945                }
946            })
947        })
948    }
949}
950
951impl HirScalarExpr {
952    /// Rewrite `self` into a `mz_expr::MirScalarExpr` which can be applied to the modified `inner`.
953    ///
954    /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
955    /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
956    /// The most complicated logic is for the scalar expressions that involve subqueries, each of which is
957    /// documented in more detail closer to its logic.
958    ///
959    /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
960    /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
961    /// each of these references can be found.
962    fn applied_to(
963        self,
964        id_gen: &mut mz_ore::id_gen::IdGen,
965        col_map: &ColumnMap,
966        cte_map: &mut CteMap,
967        inner: &mut MirRelationExpr,
968        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
969        context: &Context,
970    ) -> Result<MirScalarExpr, PlanError> {
971        maybe_grow(|| {
972            use MirScalarExpr as SS;
973
974            use HirScalarExpr::*;
975
976            if let Some(subquery_map) = subquery_map {
977                if let Some(col) = subquery_map.get(&self) {
978                    return Ok(SS::column(*col));
979                }
980            }
981
982            Ok::<MirScalarExpr, PlanError>(match self {
983                Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
984                Literal(row, typ, _name) => SS::Literal(Ok(row), ReprColumnType::from(&typ)),
985                Parameter(_, _name) => {
986                    panic!("cannot decorrelate expression with unbound parameters")
987                }
988                CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
989                CallUnary {
990                    func,
991                    expr,
992                    name: _,
993                } => {
994                    let inner =
995                        expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
996                    if context.config.enable_cast_elimination && func.is_eliminable_cast() {
997                        inner
998                    } else {
999                        SS::CallUnary {
1000                            func,
1001                            expr: Box::new(inner),
1002                        }
1003                    }
1004                }
1005                CallBinary {
1006                    func,
1007                    expr1,
1008                    expr2,
1009                    name: _,
1010                } => SS::CallBinary {
1011                    func,
1012                    expr1: Box::new(expr1.applied_to(
1013                        id_gen,
1014                        col_map,
1015                        cte_map,
1016                        inner,
1017                        subquery_map,
1018                        context,
1019                    )?),
1020                    expr2: Box::new(expr2.applied_to(
1021                        id_gen,
1022                        col_map,
1023                        cte_map,
1024                        inner,
1025                        subquery_map,
1026                        context,
1027                    )?),
1028                },
1029                CallVariadic {
1030                    func,
1031                    exprs,
1032                    name: _,
1033                } => SS::call_variadic(
1034                    func,
1035                    exprs
1036                        .into_iter()
1037                        .map(|expr| {
1038                            expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
1039                        })
1040                        .collect::<Result<_, _>>()?,
1041                ),
1042                If {
1043                    cond,
1044                    then,
1045                    els,
1046                    name,
1047                } => {
1048                    // The `If` case is complicated by the fact that we do not want to
1049                    // apply the `then` or `else` logic to tuples that respectively do
1050                    // not or do pass the `cond` test. Our strategy is to independently
1051                    // decorrelate the `then` and `else` logic, and apply each to tuples
1052                    // that respectively pass and do not pass the `cond` logic (which is
1053                    // executed, and so decorrelated, for all tuples).
1054                    //
1055                    // Informally, we turn the `if` statement into:
1056                    //
1057                    //   let then_case = inner.filter(cond).map(then);
1058                    //   let else_case = inner.filter(!cond).map(else);
1059                    //   return then_case.concat(else_case);
1060                    //
1061                    // We only require this if either expression would result in any
1062                    // computation beyond the expr itself, which we will interpret as
1063                    // "introduces additional columns". In the absence of correlation,
1064                    // we should just retain a `ScalarExpr::If` expression; the inverse
1065                    // transformation as above is complicated to recover after the fact,
1066                    // and we would benefit from not introducing the complexity.
1067
1068                    let inner_arity = inner.arity();
1069                    let cond_expr =
1070                        cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1071
1072                    // Defensive copies, in case we mangle these in decorrelation.
1073                    let inner_clone = inner.clone();
1074                    let then_clone = then.clone();
1075                    let else_clone = els.clone();
1076
1077                    let cond_arity = inner.arity();
1078                    let then_expr =
1079                        then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1080                    let else_expr =
1081                        els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1082
1083                    if cond_arity == inner.arity() {
1084                        // If no additional columns were added, we simply return the
1085                        // `If` variant with the updated expressions.
1086                        SS::If {
1087                            cond: Box::new(cond_expr),
1088                            then: Box::new(then_expr),
1089                            els: Box::new(else_expr),
1090                        }
1091                    } else {
1092                        // If columns were added, we need a more careful approach, as
1093                        // described above. First, we need to de-correlate each of
1094                        // the two expressions independently, and apply their cases
1095                        // as `MirRelationExpr::Map` operations.
1096
1097                        *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
1098                            // Restrict to records satisfying `cond_expr` and apply `then` as a map.
1099                            let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
1100                            let then_expr = then_clone.applied_to(
1101                                id_gen,
1102                                col_map,
1103                                cte_map,
1104                                &mut then_inner,
1105                                subquery_map,
1106                                context,
1107                            )?;
1108                            let then_arity = then_inner.arity();
1109                            then_inner = then_inner
1110                                .map_one(then_expr)
1111                                .project((0..inner_arity).chain(Some(then_arity)).collect());
1112
1113                            // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
1114                            let mut else_inner = get_inner.filter(vec![SS::call_variadic(
1115                                variadic::Or,
1116                                vec![
1117                                    cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
1118                                    cond_expr.clone().call_is_null(),
1119                                ],
1120                            )]);
1121                            let else_expr = else_clone.applied_to(
1122                                id_gen,
1123                                col_map,
1124                                cte_map,
1125                                &mut else_inner,
1126                                subquery_map,
1127                                context,
1128                            )?;
1129                            let else_arity = else_inner.arity();
1130                            else_inner = else_inner
1131                                .map_one(else_expr)
1132                                .project((0..inner_arity).chain(Some(else_arity)).collect());
1133
1134                            // concatenate the two results.
1135                            Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
1136                        })?;
1137
1138                        SS::Column(inner_arity, name)
1139                    }
1140                }
1141
1142                // Subqueries!
1143                // These are surprisingly subtle. Things to be careful of:
1144
1145                // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
1146                // * change the row counts of the outer query
1147                // * accidentally compute its own value using the row counts of the outer query
1148                // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
1149                // query and then join the answers back on to the original rows of the outer query.
1150
1151                // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
1152                // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
1153                Exists(expr, name) => {
1154                    let apply_requires_distinct_outer = true;
1155                    *inner = apply_existential_subquery(
1156                        id_gen,
1157                        inner.take_dangerous(),
1158                        col_map,
1159                        cte_map,
1160                        *expr,
1161                        apply_requires_distinct_outer,
1162                        context,
1163                    )?;
1164                    SS::Column(inner.arity() - 1, name)
1165                }
1166
1167                Select(expr, name) => {
1168                    let apply_requires_distinct_outer = true;
1169                    *inner = apply_scalar_subquery(
1170                        id_gen,
1171                        inner.take_dangerous(),
1172                        col_map,
1173                        cte_map,
1174                        *expr,
1175                        apply_requires_distinct_outer,
1176                        context,
1177                    )?;
1178                    SS::Column(inner.arity() - 1, name)
1179                }
1180                Windowing(expr, _name) => {
1181                    let partition_by = expr.partition_by;
1182                    let order_by = expr.order_by;
1183
1184                    // argument lowering for scalar window functions
1185                    // (We need to specify the & _ in the arguments because of this problem:
1186                    // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
1187                    let scalar_lower_args =
1188                        |_id_gen: &mut _,
1189                         _col_map: &_,
1190                         _cte_map: &mut _,
1191                         _get_inner: &mut _,
1192                         _subquery_map: &Option<&_>,
1193                         order_by_mir: Vec<MirScalarExpr>,
1194                         original_row_record,
1195                         original_row_record_type: SqlScalarType| {
1196                            let agg_input = MirScalarExpr::call_variadic(
1197                                variadic::ListCreate {
1198                                    elem_type: original_row_record_type.clone(),
1199                                },
1200                                vec![original_row_record],
1201                            );
1202                            let mut agg_input = vec![agg_input];
1203                            agg_input.extend(order_by_mir.clone());
1204                            let agg_input = MirScalarExpr::call_variadic(
1205                                variadic::RecordCreate {
1206                                    field_names: (0..agg_input.len())
1207                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1208                                        .collect_vec(),
1209                                },
1210                                agg_input,
1211                            );
1212                            let list_type = SqlScalarType::List {
1213                                element_type: Box::new(original_row_record_type),
1214                                custom_id: None,
1215                            };
1216                            let agg_input_type = SqlScalarType::Record {
1217                                fields: std::iter::once(&list_type)
1218                                    .map(|t| {
1219                                        (
1220                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1221                                            t.clone().nullable(false),
1222                                        )
1223                                    })
1224                                    .collect(),
1225                                custom_id: None,
1226                            }
1227                            .nullable(false);
1228
1229                            Ok((agg_input, agg_input_type))
1230                        };
1231
1232                    // argument lowering for value window functions and aggregate window functions
1233                    let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
1234                        |id_gen: &mut _,
1235                         col_map: &_,
1236                         cte_map: &mut _,
1237                         get_inner: &mut _,
1238                         subquery_map: &Option<&_>,
1239                         order_by_mir: Vec<MirScalarExpr>,
1240                         original_row_record,
1241                         original_row_record_type| {
1242                            // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]
1243
1244                            // Compute the encoded args for all rows
1245                            let mir_encoded_args = hir_encoded_args.applied_to(
1246                                id_gen,
1247                                col_map,
1248                                cte_map,
1249                                get_inner,
1250                                subquery_map,
1251                                context,
1252                            )?;
1253                            let mir_encoded_args_type = mir_encoded_args
1254                                .sql_typ(&get_inner.sql_typ().column_types)
1255                                .scalar_type;
1256
1257                            // Build a new record that has two fields:
1258                            // 1. the original row in a record
1259                            // 2. the encoded args (which can be either a single value, or a record
1260                            //    if the window function has multiple arguments, such as `lag`)
1261                            let fn_input_record_fields: Box<[_]> =
1262                                [original_row_record_type, mir_encoded_args_type]
1263                                    .iter()
1264                                    .map(|t| {
1265                                        (
1266                                            ColumnName::from(UNKNOWN_COLUMN_NAME),
1267                                            t.clone().nullable(false),
1268                                        )
1269                                    })
1270                                    .collect();
1271                            let fn_input_record = MirScalarExpr::call_variadic(
1272                                variadic::RecordCreate {
1273                                    field_names: fn_input_record_fields
1274                                        .iter()
1275                                        .map(|(n, _)| n.clone())
1276                                        .collect_vec(),
1277                                },
1278                                vec![original_row_record, mir_encoded_args],
1279                            );
1280                            let fn_input_record_type = SqlScalarType::Record {
1281                                fields: fn_input_record_fields,
1282                                custom_id: None,
1283                            }
1284                            .nullable(false);
1285
1286                            // Build a new record with the record above + the ORDER BY exprs
1287                            // This follows the standard encoding of ORDER BY exprs used by aggregate functions
1288                            let mut agg_input = vec![fn_input_record];
1289                            agg_input.extend(order_by_mir.clone());
1290                            let agg_input = MirScalarExpr::call_variadic(
1291                                variadic::RecordCreate {
1292                                    field_names: (0..agg_input.len())
1293                                        .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
1294                                        .collect_vec(),
1295                                },
1296                                agg_input,
1297                            );
1298
1299                            let agg_input_type = SqlScalarType::Record {
1300                                fields: [(
1301                                    ColumnName::from(UNKNOWN_COLUMN_NAME),
1302                                    fn_input_record_type.nullable(false),
1303                                )]
1304                                .into(),
1305                                custom_id: None,
1306                            }
1307                            .nullable(false);
1308
1309                            Ok((agg_input, agg_input_type))
1310                        }
1311                    };
1312
1313                    match expr.func {
1314                        WindowExprType::Scalar(scalar_window_expr) => {
1315                            let mir_aggr_func = scalar_window_expr.into_expr();
1316                            Self::window_func_applied_to(
1317                                id_gen,
1318                                col_map,
1319                                cte_map,
1320                                inner,
1321                                subquery_map,
1322                                partition_by,
1323                                order_by,
1324                                mir_aggr_func,
1325                                scalar_lower_args,
1326                                context,
1327                            )?
1328                        }
1329                        WindowExprType::Value(value_window_expr) => {
1330                            let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();
1331
1332                            Self::window_func_applied_to(
1333                                id_gen,
1334                                col_map,
1335                                cte_map,
1336                                inner,
1337                                subquery_map,
1338                                partition_by,
1339                                order_by,
1340                                mir_aggr_func,
1341                                value_or_aggr_lower_args(hir_encoded_args),
1342                                context,
1343                            )?
1344                        }
1345                        WindowExprType::Aggregate(aggr_window_expr) => {
1346                            let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();
1347
1348                            Self::window_func_applied_to(
1349                                id_gen,
1350                                col_map,
1351                                cte_map,
1352                                inner,
1353                                subquery_map,
1354                                partition_by,
1355                                order_by,
1356                                mir_aggr_func,
1357                                value_or_aggr_lower_args(hir_encoded_args),
1358                                context,
1359                            )?
1360                        }
1361                    }
1362                }
1363            })
1364        })
1365    }
1366
    /// Lowers a single window function call over `inner` and returns a reference to the column
    /// holding the window function's result.
    ///
    /// On success, `*inner` is replaced by a relation whose columns are the columns of `inner`
    /// (as counted by `input_arity` below, i.e. after the ORDER BY expressions have been lowered
    /// but before any computed partition-key columns are added), followed by one extra column
    /// with the window function's result. The returned expression references that last column.
    ///
    /// Steps:
    /// 1. Lower the `order_by` expressions against `inner`.
    /// 2. Build the grouping key from the outer columns in `col_map` (sorted) plus the lowered
    ///    `partition_by` expressions. Partition expressions that are not plain column references
    ///    are first appended to the relation via `map_one`.
    /// 3. Pack each row's first `input_arity` columns into a record and pass it to `lower_args`,
    ///    together with the lowered ORDER BY expressions. `lower_args` returns the aggregate's
    ///    input expression and that input's type.
    /// 4. Reduce with `mir_aggr_func` over the grouping key, unnest the resulting list, and
    ///    unpack each list element. Field 1 holds the original row record and field 0 holds the
    ///    window function's result.
    ///
    /// Errors from lowering `order_by` or `partition_by`, and errors returned by `lower_args`,
    /// are propagated to the caller.
    fn window_func_applied_to<F>(
        id_gen: &mut mz_ore::id_gen::IdGen,
        col_map: &ColumnMap,
        cte_map: &mut CteMap,
        inner: &mut MirRelationExpr,
        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
        partition_by: Vec<HirScalarExpr>,
        order_by: Vec<HirScalarExpr>,
        mir_aggr_func: AggregateFunc,
        lower_args: F,
        context: &Context,
    ) -> Result<MirScalarExpr, PlanError>
    where
        // Builds the aggregate input. Arguments, in order: id generator, column map, CTE map,
        // the relation being lowered, subquery map, lowered ORDER BY expressions, the original
        // row packed as a record, and that record's type. Returns the aggregate input
        // expression and its column type.
        F: FnOnce(
            &mut mz_ore::id_gen::IdGen,
            &ColumnMap,
            &mut CteMap,
            &mut MirRelationExpr,
            &Option<&BTreeMap<HirScalarExpr, usize>>,
            Vec<MirScalarExpr>,
            MirScalarExpr,
            SqlScalarType,
        ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
    {
        // Example MIRs for a window function (specifically, a window aggregation):
        //
        // CREATE TABLE t7(x INT, y INT);
        //
        // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
        //
        // Decorrelated Plan
        // Project (#3)
        //   Map (#2)
        //     Project (#3..=#5)
        //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
        //         FlatMap unnest_list(#1)
        //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
        //             Map ((#0 + #1))
        //               CrossJoin
        //                 Constant
        //                   - ()
        //                 Get materialize.public.t7
        //
        // The same query after optimizations:
        //
        // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
        //
        // Optimized Plan
        // Explained Query:
        //   Project (#2)
        //     Map (record_get[0](#1))
        //       FlatMap unnest_list(#0)
        //         Project (#1)
        //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
        //             ReadStorage materialize.public.t7
        //
        // The `row(row(row(...), ...), ...)` stuff means the following:
        // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
        //   - The <arguments to window function> can be either a single value or itself a
        //     `row` if there are multiple arguments.
        //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
        //     ORDER BY columns.
        //   - The <original row> currently always captures the entire original row. This should
        //     improve when we make `ProjectionPushdown` smarter, see
        //     https://github.com/MaterializeInc/database-issues/issues/5090
        //
        // TODO:
        // We should probably introduce some dedicated Datum constructor functions instead of `row`
        // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
        // might even introduce dedicated Datum enum variants, so that the rendering code also
        // becomes more readable (and possibly slightly more performant).

        *inner = inner
            .take_dangerous()
            .let_in(id_gen, |id_gen, mut get_inner| {
                // Lower the ORDER BY expressions first; lowering receives `&mut get_inner`, so it
                // may extend the relation before `input_arity` is recorded below.
                let order_by_mir = order_by
                    .into_iter()
                    .map(|o| {
                        o.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut get_inner,
                            subquery_map,
                            context,
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                // Record input arity here so that any group_keys that need to mutate get_inner
                // don't add those columns to the aggregate input.
                let input_type = get_inner.sql_typ();
                let input_arity = input_type.arity();
                // The reduction that computes the window function must be keyed on the columns
                // from the outer context, plus the expressions in the partition key. The current
                // subquery will be 'executed' for every distinct row from the outer context so
                // by putting the outer columns in the grouping key we isolate each re-execution.
                let mut group_key = col_map
                    .inner
                    .iter()
                    .map(|(_, outer_col)| *outer_col)
                    .sorted()
                    .collect_vec();
                for p in partition_by {
                    let key = p.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        context,
                    )?;
                    // A plain column reference can be used as a grouping key directly. Any other
                    // expression is appended as a new column, and that column becomes the key.
                    if let MirScalarExpr::Column(c, _name) = key {
                        group_key.push(c);
                    } else {
                        get_inner = get_inner.map_one(key);
                        group_key.push(get_inner.arity() - 1);
                    }
                }

                // Bind the relation again (it may now carry computed partition-key columns) so
                // that the code below can refer to it.
                get_inner.let_in(id_gen, |id_gen, mut get_inner| {
                    // Original columns of the relation
                    let fields: Box<_> = input_type
                        .column_types
                        .iter()
                        .take(input_arity)
                        .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
                        .collect();

                    // Original row made into a record
                    let original_row_record = MirScalarExpr::call_variadic(
                        variadic::RecordCreate {
                            field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
                        },
                        (0..input_arity).map(MirScalarExpr::column).collect_vec(),
                    );
                    let original_row_record_type = SqlScalarType::Record {
                        fields,
                        custom_id: None,
                    };

                    // Delegate construction of the aggregate input (and its type) to the
                    // function-kind-specific callback.
                    let (agg_input, agg_input_type) = lower_args(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        order_by_mir,
                        original_row_record,
                        original_row_record_type,
                    )?;

                    let aggregate = mz_expr::AggregateExpr {
                        func: mir_aggr_func,
                        expr: agg_input,
                        distinct: false,
                    };

                    // Actually call reduce with the window function
                    // The output of the aggregation function should be a list of tuples that has
                    // the result in the first position, and the original row in the second position
                    // The Reduce outputs the group key columns followed by the aggregate, so the
                    // aggregated list sits at column `group_key.len()`, which is what is unnested.
                    let mut reduce = get_inner
                        .reduce(group_key.clone(), vec![aggregate.clone()], None)
                        .flat_map(
                            mz_expr::TableFunc::UnnestList {
                                el_typ: aggregate
                                    .func
                                    .output_sql_type(agg_input_type)
                                    .scalar_type
                                    .unwrap_list_element_type()
                                    .clone(),
                            },
                            vec![MirScalarExpr::column(group_key.len())],
                        );
                    // The unnested list element is the last column.
                    let record_col = reduce.arity() - 1;

                    // Unpack the record output by the window function
                    // (field 1 of the element is the original row; take each of its columns).
                    for c in 0..input_arity {
                        reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
                            expr: Box::new(MirScalarExpr::CallUnary {
                                func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
                                expr: Box::new(MirScalarExpr::column(record_col)),
                            }),
                        });
                    }

                    // Append the column with the result of the window function.
                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
                        expr: Box::new(MirScalarExpr::column(record_col)),
                    });

                    // Keep only the `input_arity` unpacked columns and the result column. This
                    // drops the group key, the aggregated list and the raw list element.
                    let agg_col = record_col + 1 + input_arity;
                    Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
                })
            })?;
        // The window function's result is the last column of the new `inner`.
        Ok(MirScalarExpr::column(inner.arity() - 1))
    }
1566
1567    /// Applies the subqueries in the given list of scalar expressions to every distinct
1568    /// value of the given relation and returns a join of the given relation with all
1569    /// the subqueries found, and the mapping of scalar expressions with columns projected
1570    /// by the returned join that will hold their results.
1571    fn lower_subqueries(
1572        exprs: &[Self],
1573        id_gen: &mut mz_ore::id_gen::IdGen,
1574        col_map: &ColumnMap,
1575        cte_map: &mut CteMap,
1576        inner: MirRelationExpr,
1577        context: &Context,
1578    ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1579        let mut subquery_map = BTreeMap::new();
1580        let output = inner.let_in(id_gen, |id_gen, get_inner| {
1581            let mut subqueries = Vec::new();
1582            let distinct_inner = get_inner.clone().distinct();
1583            for expr in exprs.iter() {
1584                expr.visit_pre_post(
1585                    &mut |e| match e {
1586                        // For simplicity, subqueries within a conditional statement will be
1587                        // lowered when lowering the conditional expression.
1588                        HirScalarExpr::If { .. } => Some(vec![]),
1589                        _ => None,
1590                    },
1591                    &mut |e| match e {
1592                        HirScalarExpr::Select(expr, _name) => {
1593                            let apply_requires_distinct_outer = false;
1594                            let subquery = apply_scalar_subquery(
1595                                id_gen,
1596                                distinct_inner.clone(),
1597                                col_map,
1598                                cte_map,
1599                                (**expr).clone(),
1600                                apply_requires_distinct_outer,
1601                                context,
1602                            )
1603                            .unwrap();
1604
1605                            subqueries.push((e.clone(), subquery));
1606                        }
1607                        HirScalarExpr::Exists(expr, _name) => {
1608                            let apply_requires_distinct_outer = false;
1609                            let subquery = apply_existential_subquery(
1610                                id_gen,
1611                                distinct_inner.clone(),
1612                                col_map,
1613                                cte_map,
1614                                (**expr).clone(),
1615                                apply_requires_distinct_outer,
1616                                context,
1617                            )
1618                            .unwrap();
1619                            subqueries.push((e.clone(), subquery));
1620                        }
1621                        _ => {}
1622                    },
1623                )?;
1624            }
1625
1626            if subqueries.is_empty() {
1627                Ok::<MirRelationExpr, PlanError>(get_inner)
1628            } else {
1629                let inner_arity = get_inner.arity();
1630                let mut total_arity = inner_arity;
1631                let mut join_inputs = vec![get_inner];
1632                let mut join_input_arities = vec![inner_arity];
1633                for (expr, subquery) in subqueries.into_iter() {
1634                    // Avoid lowering duplicated subqueries
1635                    if !subquery_map.contains_key(&expr) {
1636                        let subquery_arity = subquery.arity();
1637                        assert_eq!(subquery_arity, inner_arity + 1);
1638                        join_inputs.push(subquery);
1639                        join_input_arities.push(subquery_arity);
1640                        total_arity += subquery_arity;
1641
1642                        // Column with the value of the subquery
1643                        subquery_map.insert(expr, total_arity - 1);
1644                    }
1645                }
1646                // Each subquery projects all the columns of the outer context (distinct_inner)
1647                // plus 1 column, containing the result of the subquery. Those columns must be
1648                // joined with the outer/main relation (get_inner).
1649                let input_mapper =
1650                    mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1651                let equivalences = (0..inner_arity)
1652                    .map(|col| {
1653                        join_inputs
1654                            .iter()
1655                            .enumerate()
1656                            .map(|(input, _)| {
1657                                MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1658                            })
1659                            .collect_vec()
1660                    })
1661                    .collect_vec();
1662                Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1663            }
1664        })?;
1665        Ok((output, subquery_map))
1666    }
1667
1668    /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1669    ///
1670    /// Returns an _internal_ error if the expression contains
1671    /// - a subquery
1672    /// - a column reference to an outer level
1673    /// - a parameter
1674    /// - a window function call
1675    ///
1676    /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1677    ///
1678    /// Set `enable_cast_elimination` to remove casts that are noops in MIR.
1679    pub fn lower_uncorrelated<C: Into<Config>>(
1680        self,
1681        config: C,
1682    ) -> Result<MirScalarExpr, PlanError> {
1683        let config = config.into();
1684
1685        use MirScalarExpr as SS;
1686
1687        use HirScalarExpr::*;
1688
1689        Ok(match self {
1690            Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1691            Literal(datum, typ, _name) => SS::Literal(Ok(datum), ReprColumnType::from(&typ)),
1692            CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1693            CallUnary {
1694                func,
1695                expr,
1696                name: _,
1697            } => {
1698                let inner = expr.lower_uncorrelated(config)?;
1699
1700                if config.enable_cast_elimination && func.is_eliminable_cast() {
1701                    inner
1702                } else {
1703                    SS::CallUnary {
1704                        func,
1705                        expr: Box::new(inner),
1706                    }
1707                }
1708            }
1709            CallBinary {
1710                func,
1711                expr1,
1712                expr2,
1713                name: _,
1714            } => SS::CallBinary {
1715                func,
1716                expr1: Box::new(expr1.lower_uncorrelated(config)?),
1717                expr2: Box::new(expr2.lower_uncorrelated(config)?),
1718            },
1719            CallVariadic {
1720                func,
1721                exprs,
1722                name: _,
1723            } => SS::call_variadic(
1724                func,
1725                exprs
1726                    .into_iter()
1727                    .map(|expr| expr.lower_uncorrelated(config))
1728                    .collect::<Result<_, _>>()?,
1729            ),
1730            If {
1731                cond,
1732                then,
1733                els,
1734                name: _,
1735            } => SS::If {
1736                cond: Box::new(cond.lower_uncorrelated(config)?),
1737                then: Box::new(then.lower_uncorrelated(config)?),
1738                els: Box::new(els.lower_uncorrelated(config)?),
1739            },
1740            Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1741                sql_bail!(
1742                    "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1743                    self
1744                );
1745            }
1746        })
1747    }
1748}
1749
1750/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
1751/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
1752/// will, in effect, be executed once for every distinct row in `outer`, and the
1753/// results will be joined with `outer`. Note that columns in `outer` that are
1754/// not depended upon by `inner` are thrown away before the distinct, so that we
1755/// don't perform needless computation of `inner`.
1756///
1757/// `branch` will inspect the contents of `inner` to determine whether `inner`
1758/// is not multiplicity sensitive (roughly, contains only maps, filters,
1759/// projections, and calls to table functions). If it is not multiplicity
1760/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
1761/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
1762/// operations that were not present in `inner`, then set
1763/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
1764/// that distinctifies `outer`.
1765///
1766/// The caller must supply the `apply` function that applies the rewritten
1767/// `inner` to `outer`.
1768fn branch<F>(
1769    id_gen: &mut mz_ore::id_gen::IdGen,
1770    outer: MirRelationExpr,
1771    col_map: &ColumnMap,
1772    cte_map: &mut CteMap,
1773    inner: HirRelationExpr,
1774    apply_requires_distinct_outer: bool,
1775    context: &Context,
1776    apply: F,
1777) -> Result<MirRelationExpr, PlanError>
1778where
1779    F: FnOnce(
1780        &mut mz_ore::id_gen::IdGen,
1781        HirRelationExpr,
1782        MirRelationExpr,
1783        &ColumnMap,
1784        &mut CteMap,
1785        &Context,
1786    ) -> Result<MirRelationExpr, PlanError>,
1787{
1788    // TODO: It would be nice to have a version of this code w/o optimizations,
1789    // at the least for purposes of understanding. It was difficult for one reader
1790    // to understand the required properties of `outer` and `col_map`.
1791
1792    // If the inner expression is sufficiently simple, it is safe to apply it
1793    // *directly* to outer, rather than applying it to the distinctified key
1794    // (see below).
1795    //
1796    // As an example, consider the following two queries:
1797    //
1798    //     CREATE TABLE t (a int, b int);
1799    //     SELECT a, series FROM t, generate_series(1, t.b) series;
1800    //
1801    // The "simple" path for the `SELECT` yields
1802    //
1803    //     %0 =
1804    //     | Get t
1805    //     | FlatMap generate_series(1, #1)
1806    //
1807    // while the non-simple path yields:
1808    //
1809    //    %0 =
1810    //    | Get t
1811    //
1812    //    %1 =
1813    //    | Get t
1814    //    | Distinct group=(#1)
1815    //    | FlatMap generate_series(1, #0)
1816    //
1817    //    %2 =
1818    //    | LeftJoin %1 %2 (= #1 #2)
1819    //
1820    // There is a tradeoff here: the simple plan is stateless, but the non-
1821    // simple plan may do (much) less computation if there are only a few
1822    // distinct values of `t.b`.
1823    //
1824    // We apply a very simple heuristic here and take the simple path if `inner`
1825    // contains only maps, filters, projections, and calls to table functions.
1826    // The intuition is that straightforward usage of table functions should
1827    // take the simple path, while everything else should not. (In theory we
1828    // think this transformation is valid as long as `inner` does not contain a
1829    // Reduce, Distinct, or TopK node, but it is not always an optimization in
1830    // the general case.)
1831    //
1832    // TODO(benesch): this should all be handled by a proper optimizer, but
1833    // detecting the moment of decorrelation in the optimizer right now is too
1834    // hard.
1835    let mut is_simple = true;
1836    #[allow(deprecated)]
1837    inner.visit(0, &mut |expr, _| match expr {
1838        HirRelationExpr::Constant { .. }
1839        | HirRelationExpr::Project { .. }
1840        | HirRelationExpr::Map { .. }
1841        | HirRelationExpr::Filter { .. }
1842        | HirRelationExpr::CallTable { .. } => (),
1843        _ => is_simple = false,
1844    });
1845    if is_simple && !apply_requires_distinct_outer {
1846        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
1847        return outer.let_in(id_gen, |id_gen, get_outer| {
1848            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
1849        });
1850    }
1851
1852    // The key consists of the columns from the outer expression upon which the
1853    // inner relation depends. We discover these dependencies by walking the
1854    // inner relation expression and looking for column references whose level
1855    // escapes inner.
1856    //
1857    // At the end of this process, `key` contains the decorrelated position of
1858    // each outer column, according to the passed-in `col_map`, and
1859    // `new_col_map` maps each outer column to its new ordinal position in key.
1860    let mut outer_cols = BTreeSet::new();
1861    #[allow(deprecated)]
1862    inner.visit_columns(0, &mut |depth, col| {
1863        // Test if the column reference escapes the subquery.
1864        if col.level > depth {
1865            outer_cols.insert(ColumnRef {
1866                level: col.level - depth,
1867                column: col.column,
1868            });
1869        }
1870    });
1871    // Collect all the outer columns referenced by any CTE referenced by
1872    // the inner relation.
1873    #[allow(deprecated)]
1874    inner.visit(0, &mut |e, _| match e {
1875        HirRelationExpr::Get {
1876            id: mz_expr::Id::Local(id),
1877            ..
1878        } => {
1879            if let Some(cte_desc) = cte_map.get(id) {
1880                let cte_outer_arity = cte_desc.outer_relation.arity();
1881                outer_cols.extend(
1882                    col_map
1883                        .inner
1884                        .iter()
1885                        .filter(|(_, position)| **position < cte_outer_arity)
1886                        .map(|(c, _)| {
1887                            // `col_map` maps column references to column positions in
1888                            // `outer`'s projection.
1889                            // `outer_cols` is meant to contain the external column
1890                            // references in `inner`.
1891                            // Since `inner` defines a new scope, any column reference
1892                            // in `col_map` is one level deeper when seen from within
1893                            // `inner`, hence the +1.
1894                            ColumnRef {
1895                                level: c.level + 1,
1896                                column: c.column,
1897                            }
1898                        }),
1899                );
1900            }
1901        }
1902        HirRelationExpr::Let { id, .. } => {
1903            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
1904            // we would need to remove the old CTE with the same ID temporarily while
1905            // traversing the definition of the new CTE under the same ID.
1906            assert!(!cte_map.contains_key(id));
1907        }
1908        _ => {}
1909    });
1910    let mut new_col_map = BTreeMap::new();
1911    let mut key = vec![];
1912    for col in outer_cols {
1913        new_col_map.insert(col, key.len());
1914        key.push(col_map.get(&ColumnRef {
1915            // Note: `outer_cols` contains the external column references within `inner`.
1916            // We must compensate for `inner`'s scope when translating column references
1917            // as seen within `inner` to column references as seen from `outer`'s context,
1918            // hence the -1.
1919            level: col.level - 1,
1920            column: col.column,
1921        }));
1922    }
1923    let new_col_map = ColumnMap::new(new_col_map);
1924    outer.let_in(id_gen, |id_gen, get_outer| {
1925        let keyed_outer = if key.is_empty() {
1926            // Don't depend on outer at all if the branch is not correlated,
1927            // which yields vastly better query plans. Note that this is a bit
1928            // weird in that the branch will be computed even if outer has no
1929            // rows, whereas if it had been correlated it would not (and *could*
1930            // not) have been computed if outer had no rows, but the callers of
1931            // this function don't mind these somewhat-weird semantics.
1932            MirRelationExpr::constant(vec![vec![]], ReprRelationType::new(vec![]))
1933        } else {
1934            get_outer.clone().distinct_by(key.clone())
1935        };
1936        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
1937            let oa = get_outer.arity();
1938            let branch = apply(
1939                id_gen,
1940                inner,
1941                get_keyed_outer,
1942                &new_col_map,
1943                cte_map,
1944                context,
1945            )?;
1946            let ba = branch.arity();
1947            let joined = MirRelationExpr::join(
1948                vec![get_outer.clone(), branch],
1949                key.iter()
1950                    .enumerate()
1951                    .map(|(i, &k)| vec![(0, k), (1, i)])
1952                    .collect(),
1953            )
1954            // throw away the right-hand copy of the key we just joined on
1955            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
1956            Ok(joined)
1957        })
1958    })
1959}
1960
1961fn apply_scalar_subquery(
1962    id_gen: &mut mz_ore::id_gen::IdGen,
1963    outer: MirRelationExpr,
1964    col_map: &ColumnMap,
1965    cte_map: &mut CteMap,
1966    scalar_subquery: HirRelationExpr,
1967    apply_requires_distinct_outer: bool,
1968    context: &Context,
1969) -> Result<MirRelationExpr, PlanError> {
1970    branch(
1971        id_gen,
1972        outer,
1973        col_map,
1974        cte_map,
1975        scalar_subquery,
1976        apply_requires_distinct_outer,
1977        context,
1978        |id_gen, expr, get_inner, col_map, cte_map, context| {
1979            // compute for every row in get_inner
1980            let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1981            let col_type = select.sql_typ().column_types.into_last();
1982
1983            let inner_arity = get_inner.arity();
1984            // We must determine a count for each `get_inner` prefix,
1985            // and report an error if that count exceeds one.
1986            let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1987                // Count for each `get_inner` prefix.
1988                let counts = get_select.clone().reduce(
1989                    (0..inner_arity).collect::<Vec<_>>(),
1990                    vec![mz_expr::AggregateExpr {
1991                        func: mz_expr::AggregateFunc::Count,
1992                        expr: MirScalarExpr::literal_true(),
1993                        distinct: false,
1994                    }],
1995                    None,
1996                );
1997
1998                // Errors should result from counts > 1.
1999                let errors = counts
2000                    .flat_map(
2001                        mz_expr::TableFunc::GuardSubquerySize {
2002                            column_type: col_type.clone().scalar_type,
2003                        },
2004                        vec![MirScalarExpr::column(inner_arity)],
2005                    )
2006                    .project(
2007                        (0..inner_arity)
2008                            .chain(Some(inner_arity + 1))
2009                            .collect::<Vec<_>>(),
2010                    );
2011                // Return `get_select` and any errors added in.
2012                Ok::<_, PlanError>(get_select.union(errors))
2013            })?;
2014            // append Null to anything that didn't return any rows
2015            let default = vec![(Datum::Null, ReprScalarType::from(&col_type.scalar_type))];
2016            get_inner.lookup(id_gen, guarded, default)
2017        },
2018    )
2019}
2020
2021fn apply_existential_subquery(
2022    id_gen: &mut mz_ore::id_gen::IdGen,
2023    outer: MirRelationExpr,
2024    col_map: &ColumnMap,
2025    cte_map: &mut CteMap,
2026    subquery_expr: HirRelationExpr,
2027    apply_requires_distinct_outer: bool,
2028    context: &Context,
2029) -> Result<MirRelationExpr, PlanError> {
2030    branch(
2031        id_gen,
2032        outer,
2033        col_map,
2034        cte_map,
2035        subquery_expr,
2036        apply_requires_distinct_outer,
2037        context,
2038        |id_gen, expr, get_inner, col_map, cte_map, context| {
2039            let exists = expr
2040                // compute for every row in get_inner
2041                .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2042                // throw away actual values and just remember whether or not there were __any__ rows
2043                .distinct_by((0..get_inner.arity()).collect())
2044                // Append true to anything that returned any rows.
2045                .map(vec![MirScalarExpr::literal_true()]);
2046
2047            // append False to anything that didn't return any rows
2048            get_inner.lookup(id_gen, exists, vec![(Datum::False, ReprScalarType::Bool)])
2049        },
2050    )
2051}
2052
2053impl AggregateExpr {
2054    fn applied_to(
2055        self,
2056        id_gen: &mut mz_ore::id_gen::IdGen,
2057        col_map: &ColumnMap,
2058        cte_map: &mut CteMap,
2059        inner: &mut MirRelationExpr,
2060        context: &Context,
2061    ) -> Result<mz_expr::AggregateExpr, PlanError> {
2062        let AggregateExpr {
2063            func,
2064            expr,
2065            distinct,
2066        } = self;
2067
2068        Ok(mz_expr::AggregateExpr {
2069            func: func.into_expr(),
2070            expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2071            distinct,
2072        })
2073    }
2074}
2075
/// Attempts an efficient outer join, if `on` has equijoin structure.
///
/// Both `left` and `right` are decorrelated inputs.
///
/// The first `oa` columns correspond to an outer context: we should do the
/// outer join independently for each prefix. In the case that `on` contains
/// just some equality tests between columns of `left` and `right` and some
/// local predicates, we can employ a relatively simple plan.
///
/// The last `on_subquery_types.len()` columns correspond to results from
/// subqueries defined in the `on` clause - we treat those as theta-join
/// conditions that prohibit the use of the simple plan attempted here.
///
/// Returns `Ok(None)` if the classified `on` predicates are not an equijoin
/// according to [`OnPredicates::is_equijoin`], i.e. the simple plan does not
/// apply. Otherwise returns `Ok(Some(expr))`, where `expr` has the schema
/// `outer × left × right` (the repeated `outer` columns of `right` are
/// dropped) and consists of:
///
/// - the inner equijoin of `left` and `right`, filtered by `on`;
/// - for `LeftOuter` and `FullOuter` joins, the `left` rows without a match,
///   padded with typed nulls for the `right` columns;
/// - for `RightOuter` and `FullOuter` joins, the `right` rows without a match,
///   prepended with typed nulls for the `left` columns.
fn attempt_outer_equijoin(
    left: MirRelationExpr,
    right: MirRelationExpr,
    on: MirScalarExpr,
    on_subquery_types: Vec<ReprColumnType>,
    kind: JoinKind,
    oa: usize,
    id_gen: &mut mz_ore::id_gen::IdGen,
    context: &Context,
) -> Result<Option<MirRelationExpr>, PlanError> {
    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
    // predicates that reference subqueries as long as these subqueries don't
    // reference `left` and `right` at the same time.
    //
    // TODO(database-issues#6828): This code can be improved as follows:
    //
    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
    // 2. Use the canonicalized `on` predicate in the non-equijoin based
    //    lowering strategy.
    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
    // 4. Pass the classified `OnPredicates` as a parameter.
    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
    //
    // Steps (1 + 2) require further investigation because we might change the
    // error semantics in case the `on` predicate contains a literal error.

    let l_type = left.typ();
    let r_type = right.typ();
    // Arities of `left` and `right` without their shared `outer` prefix.
    let la = l_type.column_types.len() - oa;
    let ra = r_type.column_types.len() - oa;
    // Arity of the `on` subquery results.
    let sa = on_subquery_types.len();

    // The output type contains [outer, left, right, sa] attributes.
    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
    output_type.extend(l_type.column_types);
    output_type.extend(r_type.column_types.into_iter().skip(oa));
    output_type.extend(on_subquery_types);

    // Generally healthy to do, but specifically `USING` conditions sometimes
    // put an `AND true` at the end of the `ON` condition.
    //
    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
    // However, in that case it's not clear that we won't see regressions if
    // `on` simplifies to a literal error.
    let mut on = vec![on];
    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);

    // Form the left and right types without the outer attributes.
    output_type.drain(0..oa);
    let lt = output_type.drain(0..la).collect_vec();
    let rt = output_type.drain(0..ra).collect_vec();
    assert!(output_type.len() == sa);

    // Classify the canonicalized predicates; bail out if the simple plan
    // does not apply.
    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
    if !on_predicates.is_equijoin(context) {
        return Ok(None);
    }

    // If we've gotten this far, we can do the clever thing.
    // We'll want to use left and right multiple times
    let result = left.let_in(id_gen, |id_gen, get_left| {
        right.let_in(id_gen, |id_gen, get_right| {
            // TODO: we know that we can re-use the arrangements of left and right
            // needed for the inner join with each of the conditional outer joins.
            // It is not clear whether we should hint that, or just let the planner
            // and optimizer run and see what happens.

            // We'll want the inner join (minus repeated columns)
            let join = MirRelationExpr::join(
                vec![get_left.clone(), get_right.clone()],
                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // remove those columns from `right` repeating the first `oa` columns.
            .project(
                (0..(oa + la))
                    .chain((oa + la + oa)..(oa + la + oa + ra))
                    .collect(),
            )
            // apply the filter constraints here, to ensure nulls are not matched.
            .filter(on);

            // We'll want to re-use the results of the join multiple times.
            join.let_in(id_gen, |id_gen, get_join| {
                // Start from the inner join; unmatched rows are unioned in below
                // depending on `kind`.
                let mut result = get_join.clone();

                // A collection of keys present in both left and right collections.
                let join_keys = on_predicates.join_keys();
                let both_keys_arity = join_keys.len();
                let both_keys = get_join.restrict(join_keys).distinct();

                // The plan is now to determine the left and right rows matched in the
                // inner join, subtract them from left and right respectively, pad what
                // remains with nulls, and fold them in to `result`.

                both_keys.let_in(id_gen, |_id_gen, get_both| {
                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
                        // Rows in `left` matched in the inner equijoin. This is
                        // a semi-join between `left` and `both_keys`.
                        let left_present = MirRelationExpr::join_scalars(
                            vec![
                                get_left
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.lhs()),
                                get_both.clone(),
                            ],
                            // The keys of `get_both` start right after the
                            // `oa + la` columns of `left`.
                            itertools::zip_eq(
                                on_predicates.eq_lhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
                            )
                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
                            .collect(),
                        )
                        .project((0..(oa + la)).collect());

                        // Determine the types of nulls to use as filler.
                        let right_fill = rt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();
                        // Add to `result` absent elements, filled with typed nulls.
                        // `get_left - left_present` are the unmatched `left` rows.
                        result = left_present
                            .negate()
                            .union(get_left.clone())
                            .map(right_fill)
                            .union(result);
                    }

                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                        // Rows in `right` matched in the inner equijoin. This
                        // is a semi-join between `right` and `both_keys`.
                        let right_present = MirRelationExpr::join_scalars(
                            vec![
                                get_right
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.rhs()),
                                get_both,
                            ],
                            // The keys of `get_both` start right after the
                            // `oa + ra` columns of `right`.
                            itertools::zip_eq(
                                on_predicates.eq_rhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
                            )
                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
                            .collect(),
                        )
                        .project((0..(oa + ra)).collect());

                        // Determine the types of nulls to use as filler.
                        let left_fill = lt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, prepended with typed nulls.
                        // `get_right - right_present` are the unmatched `right` rows.
                        result = right_present
                            .negate()
                            .union(get_right.clone())
                            .map(left_fill)
                            // Permute left fill before right values.
                            .project(
                                itertools::chain!(
                                    0..oa,                 // Preserve `outer`.
                                    oa + ra..oa + la + ra, // Increment the next `la` cols by `ra`.
                                    oa..oa + ra            // Decrement the next `ra` cols by `la`.
                                )
                                .collect(),
                            )
                            .union(result)
                    }

                    Ok::<_, PlanError>(result)
                })
            })
        })
    })?;
    Ok(Some(result))
}
2266
2267/// A struct that represents the predicates in the `on` clause in a form
2268/// suitable for efficient planning outer joins with equijoin predicates.
2269struct OnPredicates {
2270    /// A store for classified `ON` predicates.
2271    ///
2272    /// Predicates that reference a single side are adjusted to assume an
2273    /// `outer × <side>` schema.
2274    predicates: Vec<OnPredicate>,
2275    /// Number of outer context columns.
2276    oa: usize,
2277}
2278
2279impl OnPredicates {
2280    const I_OUT: usize = 0; // outer context input position
2281    const I_LHS: usize = 1; // lhs input position
2282    const I_RHS: usize = 2; // rhs input position
2283    const I_SUB: usize = 3; // on subqueries input position
2284
2285    /// Classify the predicates in the `on` clause of an outer join.
2286    ///
2287    /// The other parameters are arities of the input parts:
2288    ///
2289    /// - `oa` is the arity of the `outer` context.
2290    /// - `la` is the arity of the `left` input.
2291    /// - `ra` is the arity of the `right` input.
2292    /// - `sa` is the arity of the `on` subqueries.
2293    ///
2294    /// The constructor assumes that:
2295    ///
2296    /// 1. The `on` parameter will be applied on a result that has the following
2297    ///    schema `outer × left × right × on_subqueries`.
2298    /// 2. The `on` parameter is already adjusted to assume that schema.
2299    /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2300    ///    MirScalarExpr` with `canonicalize_predicates`.
2301    fn new(
2302        oa: usize,
2303        la: usize,
2304        ra: usize,
2305        sa: usize,
2306        on: Vec<MirScalarExpr>,
2307        _context: &Context,
2308    ) -> Self {
2309        use mz_expr::BinaryFunc::Eq;
2310
2311        // Re-bind those locally for more compact pattern matching.
2312        const I_LHS: usize = OnPredicates::I_LHS;
2313        const I_RHS: usize = OnPredicates::I_RHS;
2314
2315        // Self parameters.
2316        let mut predicates = Vec::with_capacity(on.len());
2317
2318        // Helpers for populating `predicates`.
2319        let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2320        let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2321        let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2322            inner_join_mapper
2323                .lookup_inputs(expr)
2324                .filter(|&i| i != Self::I_OUT)
2325                .collect()
2326        };
2327        let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2328            inner_join_mapper
2329                .lookup_inputs(expr)
2330                .any(|i| i == Self::I_SUB)
2331        };
2332
2333        // Iterate over `on` elements and populate `predicates`.
2334        for mut predicate in on {
2335            if predicate.might_error() {
2336                tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2337                // Treat predicates that can produce a literal error as Theta.
2338                predicates.push(OnPredicate::Theta(predicate));
2339            } else if has_subquery_refs(&predicate) {
2340                tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2341                // Treat predicates referencing an `on` subquery as Theta.
2342                predicates.push(OnPredicate::Theta(predicate));
2343            } else if let MirScalarExpr::CallBinary {
2344                func: Eq(_),
2345                expr1,
2346                expr2,
2347            } = &mut predicate
2348            {
2349                // Obtain the non-outer inputs referenced by each side.
2350                let inputs1 = lookup_inputs(expr1);
2351                let inputs2 = lookup_inputs(expr2);
2352
2353                match (&inputs1[..], &inputs2[..]) {
2354                    // Neither side references an input. This could be a
2355                    // constant expression or an expression that depends only on
2356                    // the outer context.
2357                    ([], []) => {
2358                        predicates.push(OnPredicate::Const(predicate));
2359                    }
2360                    // Both sides reference different inputs.
2361                    ([I_LHS], [I_RHS]) => {
2362                        let lhs = expr1.take();
2363                        let mut rhs = expr2.take();
2364                        rhs.permute(&rhs_permutation);
2365                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2366                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2367                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2368                    }
2369                    // Both sides reference different inputs (swapped).
2370                    ([I_RHS], [I_LHS]) => {
2371                        let lhs = expr2.take();
2372                        let mut rhs = expr1.take();
2373                        rhs.permute(&rhs_permutation);
2374                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2375                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2376                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2377                    }
2378                    // Both sides reference the left input or no input.
2379                    ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2380                        predicates.push(OnPredicate::Lhs(predicate));
2381                    }
2382                    // Both sides reference the right input or no input.
2383                    ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2384                        predicate.permute(&rhs_permutation);
2385                        predicates.push(OnPredicate::Rhs(predicate));
2386                    }
2387                    // At least one side references more than one input.
2388                    _ => {
2389                        tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2390                        predicates.push(OnPredicate::Theta(predicate));
2391                    }
2392                }
2393            } else {
2394                // Obtain the non-outer inputs referenced by this predicate.
2395                let inputs = lookup_inputs(&predicate);
2396
2397                match &inputs[..] {
2398                    // The predicate references no inputs. This could be a
2399                    // constant expression or an expression that depends only on
2400                    // the outer context.
2401                    [] => {
2402                        predicates.push(OnPredicate::Const(predicate));
2403                    }
2404                    // The predicate references only the left input.
2405                    [I_LHS] => {
2406                        predicates.push(OnPredicate::Lhs(predicate));
2407                    }
2408                    // The predicate references only the right input.
2409                    [I_RHS] => {
2410                        predicate.permute(&rhs_permutation);
2411                        predicates.push(OnPredicate::Rhs(predicate));
2412                    }
2413                    // The predicate references both inputs.
2414                    _ => {
2415                        tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2416                        predicates.push(OnPredicate::Theta(predicate));
2417                    }
2418                }
2419            }
2420        }
2421
2422        Self { predicates, oa }
2423    }
2424
2425    /// Check if the predicates can be lowered with an equijoin-based strategy.
2426    fn is_equijoin(&self, context: &Context) -> bool {
2427        // Count each `OnPredicate` variant in `self.predicates`.
2428        let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2429            self.predicates.iter().fold(
2430                (0, 0, 0, 0, 0, 0),
2431                |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2432                    (
2433                        const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2434                        lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2435                        rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2436                        eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2437                        eq_cols
2438                            + usize::from(matches!(
2439                                p,
2440                                OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column()
2441                            )),
2442                        theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2443                    )
2444                },
2445            );
2446
2447        let is_equijion = if context.config.enable_new_outer_join_lowering {
2448            // New classifier.
2449            eq_cnt > 0 && theta_cnt == 0
2450        } else {
2451            // Old classifier.
2452            eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2453        };
2454
2455        // Log an entry only if this is an equijoin according to the new classifier.
2456        if eq_cnt > 0 && theta_cnt == 0 {
2457            tracing::debug!(
2458                const_cnt,
2459                lhs_cnt,
2460                rhs_cnt,
2461                eq_cnt,
2462                eq_cols,
2463                theta_cnt,
2464                "OnPredicates::is_equijoin"
2465            );
2466        }
2467
2468        is_equijion
2469    }
2470
2471    /// Return an [`MirRelationExpr`] list that represents the keys for the
2472    /// equijoin. The list will contain the outer columns as a prefix.
2473    fn join_keys(&self) -> JoinKeys {
2474        // We could return either the `lhs` or the `rhs` of the keys used to
2475        // form the inner join as they are equated by the join condition.
2476        let join_keys = self.eq_lhs().collect::<Vec<_>>();
2477
2478        if join_keys.iter().all(|k| k.is_column()) {
2479            tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2480            JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2481        } else {
2482            tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2483            JoinKeys::Scalars(join_keys)
2484        }
2485    }
2486
2487    /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2488    /// conditions in the predicates list.
2489    ///
2490    /// The iterator will start with column references to the outer columns as a
2491    /// prefix.
2492    fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2493        itertools::chain(
2494            (0..self.oa).map(MirScalarExpr::column),
2495            self.predicates.iter().filter_map(|e| match e {
2496                OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2497                _ => None,
2498            }),
2499        )
2500    }
2501
2502    /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2503    /// conditions in the predicates list.
2504    ///
2505    /// The iterator will start with column references to the outer columns as a
2506    /// prefix.
2507    fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2508        itertools::chain(
2509            (0..self.oa).map(MirScalarExpr::column),
2510            self.predicates.iter().filter_map(|e| match e {
2511                OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2512                _ => None,
2513            }),
2514        )
2515    }
2516
2517    /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2518    /// [`OnPredicate::Const`] conditions in the predicates list.
2519    fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2520        self.predicates.iter().filter_map(|p| match p {
2521            // We treat Const predicates local to both inputs.
2522            OnPredicate::Const(p) => Some(p.clone()),
2523            OnPredicate::Lhs(p) => Some(p.clone()),
2524            OnPredicate::LhsConsequence(p) => Some(p.clone()),
2525            _ => None,
2526        })
2527    }
2528
2529    /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2530    /// [`OnPredicate::Const`] conditions in the predicates list.
2531    fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2532        self.predicates.iter().filter_map(|p| match p {
2533            // We treat Const predicates local to both inputs.
2534            OnPredicate::Const(p) => Some(p.clone()),
2535            OnPredicate::Rhs(p) => Some(p.clone()),
2536            OnPredicate::RhsConsequence(p) => Some(p.clone()),
2537            _ => None,
2538        })
2539    }
2540}
2541
2542enum OnPredicate {
2543    // A predicate that is either constant or references only outer columns.
2544    Const(MirScalarExpr),
2545    // A local predicate on the left-hand side of the join, i.e., it references only the left input
2546    // and possibly outer columns.
2547    //
2548    // This is one of the original predicates from the ON clause.
2549    //
2550    // One _must_ apply this predicate.
2551    Lhs(MirScalarExpr),
2552    // A local predicate on the left-hand side of the join, i.e., it references only the left input
2553    // and possibly outer columns.
2554    //
2555    // This is not one of the original predicates from the ON clause, but is just a consequence
2556    // of an original predicate in the ON clause, where the original predicate references both
2557    // inputs, but the consequence references only the left input.
2558    //
2559    // For example, the original predicate `input1.x = input2.a` has the consequence
2560    // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
2561    // prevents null skew, and also makes more CSE opportunities available when the left input's key
2562    // doesn't have a NOT NULL constraint, saving us an arrangement.
2563    //
2564    // Applying the predicate is optional, because the original predicate will be applied anyway.
2565    LhsConsequence(MirScalarExpr),
2566    // A local predicate on the right-hand side of the join.
2567    //
2568    // This is one of the original predicates from the ON clause.
2569    //
2570    // One _must_ apply this predicate.
2571    Rhs(MirScalarExpr),
2572    // A consequence of an original ON predicate, see above.
2573    RhsConsequence(MirScalarExpr),
2574    // An equality predicate between the two sides.
2575    Eq(MirScalarExpr, MirScalarExpr),
2576    // a non-equality predicate between the two sides.
2577    #[allow(dead_code)]
2578    Theta(MirScalarExpr),
2579}
2580
2581/// A set of join keys referencing an input.
2582///
2583/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
2584/// avoid changes (and thereby possible regressions) in plans that have equijoin
2585/// predicates consisting only of column refs.
2586///
2587/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
2588/// able to get rid of this code, but as it stands `Map` simplification seems
2589/// more cumbersome than `Project` simplification, so do this just to be sure.
2590enum JoinKeys {
2591    // A predicate that is either constant or references only outer columns.
2592    Outputs(Vec<usize>),
2593    // A local predicate on the left-hand side of the join.
2594    Scalars(Vec<MirScalarExpr>),
2595}
2596
2597impl JoinKeys {
2598    fn len(&self) -> usize {
2599        match self {
2600            JoinKeys::Outputs(outputs) => outputs.len(),
2601            JoinKeys::Scalars(scalars) => scalars.len(),
2602        }
2603    }
2604}
2605
2606/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
2607/// code.
2608trait LoweringExt {
2609    /// See [`MirRelationExpr::restrict`].
2610    fn restrict(self, join_keys: JoinKeys) -> Self;
2611}
2612
2613impl LoweringExt for MirRelationExpr {
2614    /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2615    fn restrict(self, join_keys: JoinKeys) -> Self {
2616        let num_keys = join_keys.len();
2617        match join_keys {
2618            JoinKeys::Outputs(outputs) => self.project(outputs),
2619            JoinKeys::Scalars(scalars) => {
2620                let input_arity = self.arity();
2621                let outputs = (input_arity..input_arity + num_keys).collect();
2622                self.map(scalars).project(outputs)
2623            }
2624        }
2625    }
2626}