mz_sql/plan/
lowering.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Lowering is the process of transforming a `HirRelationExpr`
11//! into a `MirRelationExpr`.
12//!
13//! The most crucial part of lowering is decorrelation; i.e.: rewriting a
14//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
15//! with instances of `MirScalarExpr` that contain none of these.
16//!
17//! Informally, a subquery should be viewed as a query that is executed in
18//! the context of some outer relation, for each row of that relation. The
19//! subqueries often contain references to the columns of the outer
20//! relation.
21//!
22//! The transformation we perform maintains an `outer` relation and then
23//! traverses the relation expression that may contain references to those
24//! outer columns. As subqueries are discovered, the current relation
25//! expression is recast as the outer expression until such a point as the
26//! scalar expression's evaluation can be determined and appended to each
27//! row of the previously outer relation.
28//!
29//! It is important that the outer columns (the initial columns) act as keys
30//! for all nested computation. When counts or other aggregations are
31//! performed, they should include not only the indicated keys but also all
32//! of the outer columns.
33//!
34//! The decorrelation transformation is initialized with an empty outer
35//! relation, but it seems entirely appropriate to decorrelate queries that
36//! contain "holes" from prepared statements, as if the query was a subquery
37//! against a relation containing the assignments of values to those holes.
38
39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::visit::Visit;
44use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr};
45use mz_ore::collections::CollectionExt;
46use mz_ore::stack::maybe_grow;
47use mz_repr::*;
48
49use crate::optimizer_metrics::OptimizerMetrics;
50use crate::plan::hir::{
51    AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
52};
53use crate::plan::{PlanError, transform_hir};
54use crate::session::vars::SystemVars;
55
56mod variadic_left;
57
58/// Maps a leveled column reference to a specific column.
59///
60/// Leveled column references are nested, so that larger levels are
61/// found early in a record and level zero is found at the end.
62///
63/// The column map only stores references for levels greater than zero,
64/// and column references at level zero simply start at the first column
65/// after all prior references.
66#[derive(Debug, Clone)]
67struct ColumnMap {
68    inner: BTreeMap<ColumnRef, usize>,
69}
70
71impl ColumnMap {
72    fn empty() -> ColumnMap {
73        Self::new(BTreeMap::new())
74    }
75
76    fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
77        ColumnMap { inner }
78    }
79
80    fn get(&self, col_ref: &ColumnRef) -> usize {
81        if col_ref.level == 0 {
82            self.inner.len() + col_ref.column
83        } else {
84            self.inner[col_ref]
85        }
86    }
87
88    fn len(&self) -> usize {
89        self.inner.len()
90    }
91
92    /// Updates references in the `ColumnMap` for use in a nested scope. The
93    /// provided `arity` must specify the arity of the current scope.
94    fn enter_scope(&self, arity: usize) -> ColumnMap {
95        // From the perspective of the nested scope, all existing column
96        // references will be one level greater.
97        let existing = self
98            .inner
99            .clone()
100            .into_iter()
101            .update(|(col, _i)| col.level += 1);
102
103        // All columns in the current scope become explicit entries in the
104        // immediate parent scope.
105        let new = (0..arity).map(|i| {
106            (
107                ColumnRef {
108                    level: 1,
109                    column: i,
110                },
111                self.len() + i,
112            )
113        });
114
115        ColumnMap::new(existing.chain(new).collect())
116    }
117}
118
119/// Map with the CTEs currently in scope.
120type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
121
122/// Information about needed when finding a reference to a CTE in scope.
123#[derive(Clone)]
124struct CteDesc {
125    /// The new ID assigned to the lowered version of the CTE, which may not match
126    /// the ID of the input CTE.
127    new_id: mz_expr::LocalId,
128    /// The relation type of the CTE including the columns from the outer
129    /// context at the beginning.
130    relation_type: RelationType,
131    /// The outer relation the CTE was applied to.
132    outer_relation: MirRelationExpr,
133}
134
135#[derive(Debug)]
136pub struct Config {
137    /// Enable outer join lowering implemented in database-issues#6747.
138    pub enable_new_outer_join_lowering: bool,
139    /// Enable outer join lowering implemented in database-issues#7561.
140    pub enable_variadic_left_join_lowering: bool,
141}
142
143impl From<&SystemVars> for Config {
144    fn from(vars: &SystemVars) -> Self {
145        Self {
146            enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
147            enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
148        }
149    }
150}
151
152/// Context passed to the lowering. This is wired to most parts of the lowering.
153pub(crate) struct Context<'a> {
154    /// Feature flags affecting the behavior of lowering.
155    pub config: &'a Config,
156    /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's None, we
157    /// simply don't write metrics.
158    pub metrics: Option<&'a OptimizerMetrics>,
159}
160
161impl HirRelationExpr {
162    /// Rewrite `self` into a `MirRelationExpr`.
163    /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
164    #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
165    pub fn lower<C: Into<Config>>(
166        self,
167        config: C,
168        metrics: Option<&OptimizerMetrics>,
169    ) -> Result<MirRelationExpr, PlanError> {
170        let context = Context {
171            config: &config.into(),
172            metrics,
173        };
174        let result = match self {
175            // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
176            // to ensure that the downstream optimizer can easily bypass most
177            // irrelevant optimizations (e.g. reduce folding) for this expression
178            // without having to re-learn the fact that it is just a constant,
179            // as it would if the constant were wrapped in a Let-Get pair.
180            HirRelationExpr::Constant { rows, typ } => {
181                let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
182                MirRelationExpr::Constant {
183                    rows: Ok(rows),
184                    typ,
185                }
186            }
187            mut other => {
188                let mut id_gen = mz_ore::id_gen::IdGen::default();
189                transform_hir::split_subquery_predicates(&mut other)?;
190                transform_hir::try_simplify_quantified_comparisons(&mut other)?;
191                transform_hir::fuse_window_functions(&mut other, &context)?;
192                MirRelationExpr::constant(vec![vec![]], RelationType::new(vec![])).let_in(
193                    &mut id_gen,
194                    |id_gen, get_outer| {
195                        other.applied_to(
196                            id_gen,
197                            get_outer,
198                            &ColumnMap::empty(),
199                            &mut CteMap::new(),
200                            &context,
201                        )
202                    },
203                )?
204            }
205        };
206
207        mz_repr::explain::trace_plan(&result);
208
209        Ok(result)
210    }
211
212    /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
213    ///
214    /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
215    /// When `self` references columns of `get_outer`, much more work needs to occur.
216    ///
217    /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
218    /// perhaps not all of them. It should be used as the basis of resolving column references,
219    /// but care must be taken when adding new columns that `get_outer.arity()` is where they
220    /// will start, rather than any function of `col_map`.
221    ///
222    /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
223    /// assignment of values to outer rows.
224    fn applied_to(
225        self,
226        id_gen: &mut mz_ore::id_gen::IdGen,
227        get_outer: MirRelationExpr,
228        col_map: &ColumnMap,
229        cte_map: &mut CteMap,
230        context: &Context,
231    ) -> Result<MirRelationExpr, PlanError> {
232        maybe_grow(|| {
233            use MirRelationExpr as SR;
234
235            use HirRelationExpr::*;
236
237            if let MirRelationExpr::Get { .. } = &get_outer {
238            } else {
239                panic!(
240                    "get_outer: expected a MirRelationExpr::Get, found {:?}",
241                    get_outer
242                );
243            }
244            assert_eq!(col_map.len(), get_outer.arity());
245            Ok(match self {
246                Constant { rows, typ } => {
247                    // Constant expressions are not correlated with `get_outer`, and should be cross-products.
248                    get_outer.product(SR::Constant {
249                        rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
250                        typ,
251                    })
252                }
253                Get { id, typ } => match id {
254                    mz_expr::Id::Local(local_id) => {
255                        let cte_desc = cte_map.get(&local_id).unwrap();
256                        let get_cte = SR::Get {
257                            id: mz_expr::Id::Local(cte_desc.new_id.clone()),
258                            typ: cte_desc.relation_type.clone(),
259                            access_strategy: AccessStrategy::UnknownOrLocal,
260                        };
261                        if get_outer == cte_desc.outer_relation {
262                            // If the CTE was applied to the same exact relation, we can safely
263                            // return a `Get` relation.
264                            get_cte
265                        } else {
266                            // Otherwise, the new outer relation may contain more columns from some
267                            // intermediate scope placed between the definition of the CTE and this
268                            // reference of the CTE and/or more operations applied on top of the
269                            // outer relation.
270                            //
271                            // An example of the latter is the following query:
272                            //
273                            // SELECT *
274                            // FROM x,
275                            //      LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
276                            //              SELECT (SELECT m FROM a) FROM y) b;
277                            //
278                            // When the CTE is lowered, the outer relation is `Get x`. But then,
279                            // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
280                            // which has the same cardinality as `Get x`.
281                            //
282                            // In any case, `get_outer` is guaranteed to contain the columns of the
283                            // outer relation the CTE was applied to at its prefix. Since, we must
284                            // return a relation containing `get_outer`'s column at the beginning,
285                            // we must build a join between `get_outer` and `get_cte` on their common
286                            // columns.
287                            let oa = get_outer.arity();
288                            let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
289                            let equivalences = (0..cte_outer_columns)
290                                .map(|pos| {
291                                    vec![
292                                        MirScalarExpr::Column(pos),
293                                        MirScalarExpr::Column(pos + oa),
294                                    ]
295                                })
296                                .collect();
297
298                            // Project out the second copy of the common between `get_outer` and
299                            // `cte_desc.outer_relation`.
300                            let projection = (0..oa)
301                                .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
302                                .collect_vec();
303                            SR::join_scalars(vec![get_outer, get_cte], equivalences)
304                                .project(projection)
305                        }
306                    }
307                    _ => {
308                        // Get statements are only to external sources, and are not correlated with `get_outer`.
309                        get_outer.product(SR::Get {
310                            id,
311                            typ,
312                            access_strategy: AccessStrategy::UnknownOrLocal,
313                        })
314                    }
315                },
316                Let {
317                    name: _,
318                    id,
319                    value,
320                    body,
321                } => {
322                    let value =
323                        value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
324                    value.let_in(id_gen, |id_gen, get_value| {
325                        let (new_id, typ) = if let MirRelationExpr::Get {
326                            id: mz_expr::Id::Local(id),
327                            typ,
328                            ..
329                        } = get_value
330                        {
331                            (id, typ)
332                        } else {
333                            panic!(
334                                "get_value: expected a MirRelationExpr::Get with local Id, found {:?}",
335                                get_value
336                            );
337                        };
338                        // Add the information about the CTE to the map and remove it when
339                        // it goes out of scope.
340                        let old_value = cte_map.insert(
341                            id.clone(),
342                            CteDesc {
343                                new_id,
344                                relation_type: typ,
345                                outer_relation: get_outer.clone(),
346                            },
347                        );
348                        let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
349                        if let Some(old_value) = old_value {
350                            cte_map.insert(id, old_value);
351                        } else {
352                            cte_map.remove(&id);
353                        }
354                        body
355                    })?
356                }
357                LetRec {
358                    limit,
359                    bindings,
360                    body,
361                } => {
362                    let num_bindings = bindings.len();
363
364                    // We use the outer type with the HIR types to form MIR CTE types.
365                    let outer_column_types = get_outer.typ().column_types;
366
367                    // Rename and introduce all bindings.
368                    let mut shadowed_bindings = Vec::with_capacity(num_bindings);
369                    let mut mir_ids = Vec::with_capacity(num_bindings);
370                    for (_name, id, _value, typ) in bindings.iter() {
371                        let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
372                        mir_ids.push(mir_id);
373                        let shadowed = cte_map.insert(
374                            id.clone(),
375                            CteDesc {
376                                new_id: mir_id,
377                                relation_type: RelationType::new(
378                                    outer_column_types
379                                        .iter()
380                                        .cloned()
381                                        .chain(typ.column_types.iter().cloned())
382                                        .collect::<Vec<_>>(),
383                                ),
384                                outer_relation: get_outer.clone(),
385                            },
386                        );
387                        shadowed_bindings.push((*id, shadowed));
388                    }
389
390                    let mut mir_values = Vec::with_capacity(num_bindings);
391                    for (_name, _id, value, _typ) in bindings.into_iter() {
392                        mir_values.push(value.applied_to(
393                            id_gen,
394                            get_outer.clone(),
395                            col_map,
396                            cte_map,
397                            context,
398                        )?);
399                    }
400
401                    let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
402
403                    // Remove our bindings and reinstate any shadowed bindings.
404                    for (id, shadowed) in shadowed_bindings {
405                        if let Some(shadowed) = shadowed {
406                            cte_map.insert(id, shadowed);
407                        } else {
408                            cte_map.remove(&id);
409                        }
410                    }
411
412                    MirRelationExpr::LetRec {
413                        ids: mir_ids,
414                        values: mir_values,
415                        // Copy the limit to each binding.
416                        limits: repeat(limit).take(num_bindings).collect(),
417                        body: Box::new(mir_body),
418                    }
419                }
420                Project { input, outputs } => {
421                    // Projections should be applied to the decorrelated `inner`, and to its columns,
422                    // which means rebasing `outputs` to start `get_outer.arity()` columns later.
423                    let input =
424                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
425                    let outputs = (0..get_outer.arity())
426                        .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
427                        .collect::<Vec<_>>();
428                    input.project(outputs)
429                }
430                Map { input, mut scalars } => {
431                    // Scalar expressions may contain correlated subqueries. We must be cautious!
432
433                    // We lower scalars in chunks, and must keep track of the
434                    // arity of the HIR fragments lowered so far.
435                    let mut lowered_arity = input.arity();
436
437                    let mut input =
438                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
439
440                    // Lower subqueries in maximally sized batches, such as no subquery in the current
441                    // batch depends on columns from the same batch.
442                    // Note that subqueries in this projection may reference columns added by this
443                    // Map operator, so we need to ensure these columns exist before lowering the
444                    // subquery.
445                    while !scalars.is_empty() {
446                        let end_idx = scalars
447                            .iter_mut()
448                            .position(|s| {
449                                let mut requires_nonexistent_column = false;
450                                #[allow(deprecated)]
451                                s.visit_columns(0, &mut |depth, col| {
452                                    if col.level == depth {
453                                        requires_nonexistent_column |= col.column >= lowered_arity
454                                    }
455                                });
456                                requires_nonexistent_column
457                            })
458                            .unwrap_or(scalars.len());
459                        assert!(
460                            end_idx > 0,
461                            "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
462                            lowered_arity,
463                            scalars
464                        );
465
466                        lowered_arity = lowered_arity + end_idx;
467                        let scalars = scalars.drain(0..end_idx).collect_vec();
468
469                        let old_arity = input.arity();
470                        let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
471                            &scalars, id_gen, col_map, cte_map, input, context,
472                        )?;
473                        input = with_subqueries;
474
475                        // We will proceed sequentially through the scalar expressions, for each transforming
476                        // the decorrelated `input` into a relation with potentially more columns capable of
477                        // addressing the needs of the scalar expression.
478                        // Having done so, we add the scalar value of interest and trim off any other newly
479                        // added columns.
480                        //
481                        // The sequential traversal is present as expressions are allowed to depend on the
482                        // values of prior expressions.
483                        let mut scalar_columns = Vec::new();
484                        for scalar in scalars {
485                            let scalar = scalar.applied_to(
486                                id_gen,
487                                col_map,
488                                cte_map,
489                                &mut input,
490                                &Some(&subquery_map),
491                                context,
492                            )?;
493                            input = input.map_one(scalar);
494                            scalar_columns.push(input.arity() - 1);
495                        }
496
497                        // Discard any new columns added by the lowering of the scalar expressions
498                        input = input.project((0..old_arity).chain(scalar_columns).collect());
499                    }
500
501                    input
502                }
503                CallTable { func, exprs } => {
504                    // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
505                    // allowed to refer to the results of previous expressions, and we have a simpler
506                    // implementation that appends all relevant columns first, then applies the flatmap
507                    // operator to the result, then strips off any columns introduce by subqueries.
508
509                    let mut input = get_outer;
510                    let old_arity = input.arity();
511
512                    let exprs = exprs
513                        .into_iter()
514                        .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
515                        .collect::<Result<Vec<_>, _>>()?;
516
517                    let new_arity = input.arity();
518                    let output_arity = func.output_arity();
519                    input = input.flat_map(func, exprs);
520                    if old_arity != new_arity {
521                        // this means we added some columns to handle subqueries, and now we need to get rid of them
522                        input = input.project(
523                            (0..old_arity)
524                                .chain(new_arity..new_arity + output_arity)
525                                .collect(),
526                        );
527                    }
528                    input
529                }
530                Filter { input, predicates } => {
531                    // Filter expressions may contain correlated subqueries.
532                    // We extend `get_outer` with sufficient values to determine the value of the predicate,
533                    // then filter the results, then strip off any columns that were added for this purpose.
534                    let mut input =
535                        input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
536                    for predicate in predicates {
537                        let old_arity = input.arity();
538                        let predicate = predicate
539                            .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
540                        let new_arity = input.arity();
541                        input = input.filter(vec![predicate]);
542                        if old_arity != new_arity {
543                            // this means we added some columns to handle subqueries, and now we need to get rid of them
544                            input = input.project((0..old_arity).collect());
545                        }
546                    }
547                    input
548                }
549                Join {
550                    left,
551                    right,
552                    on,
553                    kind,
554                } if right.is_correlated() => {
555                    // A correlated join is a join in which the right expression has
556                    // access to the columns in the left expression. It turns out
557                    // this is *exactly* our branch operator, plus some additional
558                    // null handling in the case of left joins. (Right and full
559                    // lateral joins are not permitted.)
560                    //
561                    // As with normal joins, the `on` predicate may be correlated,
562                    // and we treat it as a filter that follows the branch.
563
564                    assert!(kind.can_be_correlated());
565
566                    let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
567                    left.let_in(id_gen, |id_gen, get_left| {
568                        let apply_requires_distinct_outer = false;
569                        let mut join = branch(
570                            id_gen,
571                            get_left.clone(),
572                            col_map,
573                            cte_map,
574                            *right,
575                            apply_requires_distinct_outer,
576                            context,
577                            |id_gen, right, get_left, col_map, cte_map, context| {
578                                right.applied_to(id_gen, get_left, col_map, cte_map, context)
579                            },
580                        )?;
581
582                        // Plan the `on` predicate.
583                        let old_arity = join.arity();
584                        let on =
585                            on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
586                        join = join.filter(vec![on]);
587                        let new_arity = join.arity();
588                        if old_arity != new_arity {
589                            // This means we added some columns to handle
590                            // subqueries, and now we need to get rid of them.
591                            join = join.project((0..old_arity).collect());
592                        }
593
594                        // If a left join, reintroduce any rows from the left that
595                        // are missing, with nulls filled in for the right columns.
596                        if let JoinKind::LeftOuter { .. } = kind {
597                            let default = join
598                                .typ()
599                                .column_types
600                                .into_iter()
601                                .skip(get_left.arity())
602                                .map(|typ| (Datum::Null, typ.scalar_type))
603                                .collect();
604                            get_left.lookup(id_gen, join, default)
605                        } else {
606                            Ok::<_, PlanError>(join)
607                        }
608                    })?
609                }
610                Join {
611                    left,
612                    right,
613                    on,
614                    kind,
615                } => {
                    // Lower `left <kind> JOIN right ON on`: both inputs are decorrelated against
                    // `get_outer`, joined on the outer columns, filtered by the lowered `on`, and,
                    // for outer joins, extended with the unmatched rows padded with NULLs.
                    // When enabled, a variadic lowering for stacks of left joins is tried first.
616                    if context.config.enable_variadic_left_join_lowering {
617                        // Attempt to extract a stack of left joins.
618                        if let JoinKind::LeftOuter = kind {
619                            let mut rights = vec![(&*right, &on)];
620                            let mut left_test = &left;
                            // `rights` collects `(right, on)` pairs from the outermost join
                            // inwards; `left_test` ends at the first input that is not itself a
                            // left join.
621                            while let Join {
622                                left,
623                                right,
624                                on,
625                                kind: JoinKind::LeftOuter,
626                            } = &**left_test
627                            {
628                                rights.push((&**right, on));
629                                left_test = left;
630                            }
631                            if rights.len() > 1 {
632                                // Defensively clone `cte_map` as it may be mutated.
633                                let cte_map_clone = cte_map.clone();
                                // An error or `None` from the variadic attempt is deliberately
                                // not propagated: `cte_map` is restored and lowering falls
                                // through to the general path below.
634                                if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
635                                    left_test,
636                                    rights,
637                                    id_gen,
638                                    get_outer.clone(),
639                                    col_map,
640                                    cte_map,
641                                    context,
642                                ) {
643                                    return Ok(magic);
644                                } else {
645                                    cte_map.clone_from(&cte_map_clone);
646                                }
647                            }
648                        }
649                    }
650
651                    // Both join expressions should be decorrelated, and then joined by their
652                    // leading columns to form only those pairs corresponding to the same row
653                    // of `get_outer`.
654                    //
655                    // The `on` predicate may contain correlated subqueries, and we treat it
656                    // as though it was a filter, with the caveat that we also translate outer
657                    // joins in this step. The post-filtration results need to be considered
658                    // against the records present in the left and right (decorrelated) inputs,
659                    // depending on the type of join.
660                    let oa = get_outer.arity();
                    // Below, `oa`, `la` and `ra` are the arities of the outer, left and right
                    // columns; both decorrelated inputs begin with the `oa` outer columns.
661                    let left =
662                        left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
663                    let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
664                    let la = lt.len();
665                    left.let_in(id_gen, |id_gen, get_left| {
666                        let right_col_map = col_map.enter_scope(0);
667                        let right = right.applied_to(
668                            id_gen,
669                            get_outer.clone(),
670                            &right_col_map,
671                            cte_map,
672                            context,
673                        )?;
674                        let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
675                        let ra = rt.len();
676                        right.let_in(id_gen, |id_gen, get_right| {
677                            let mut product = SR::join(
678                                vec![get_left.clone(), get_right.clone()],
679                                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
680                            )
681                            // Project away the repeated copy of get_outer's columns.
682                            .project(
683                                (0..(oa + la))
684                                    .chain((oa + la + oa)..(oa + la + oa + ra))
685                                    .collect(),
686                            );
                            // `product` now has the layout [outer (oa) | left (la) | right (ra)].
687
688                            // Decorrelate and lower the `on` clause.
689                            let on = on.applied_to(
690                                id_gen,
691                                col_map,
692                                cte_map,
693                                &mut product,
694                                &None,
695                                context,
696                            )?;
697                            // Collect the types of all subqueries appearing in
698                            // the `on` clause. The subquery results were
699                            // appended to `product` in the `on.applied_to(...)`
700                            // call above.
701                            let on_subquery_types = product
702                                .typ()
703                                .column_types
704                                .drain(oa + la + ra..)
705                                .collect_vec();
706                            // Remember if `on` had any subqueries.
707                            let on_has_subqueries = !on_subquery_types.is_empty();
708
709                            // Attempt an efficient equijoin implementation, in which outer joins are
710                            // more efficiently rendered than in general. This can return `None` if
711                            // such a plan is not possible, for example if `on` does not describe an
712                            // equijoin between columns of `left` and `right`.
713                            if kind != JoinKind::Inner {
714                                if let Some(joined) = attempt_outer_equijoin(
715                                    get_left.clone(),
716                                    get_right.clone(),
717                                    on.clone(),
718                                    on_subquery_types,
719                                    kind.clone(),
720                                    oa,
721                                    id_gen,
722                                    context,
723                                )? {
724                                    if let Some(metrics) = context.metrics {
725                                        metrics.inc_outer_join_lowering("equi");
726                                    }
727                                    return Ok(joined);
728                                }
729                            }
730
731                            // Otherwise, perform a more general join.
                            // NOTE(review): inner joins skip the equijoin attempt above and always
                            // reach this point, so they are also counted as "general" by
                            // `inc_outer_join_lowering` — confirm that is intended.
732                            if let Some(metrics) = context.metrics {
733                                metrics.inc_outer_join_lowering("general");
734                            }
735                            let mut join = product.filter(vec![on]);
736                            if on_has_subqueries {
737                                // This means that `on.applied_to(...)` appended
738                                // some columns to handle subqueries, and now we
739                                // need to get rid of them.
740                                join = join.project((0..oa + la + ra).collect());
741                            }
742                            join.let_in(id_gen, |id_gen, get_join| {
743                                let mut result = get_join.clone();
                                // Left (and full) outer joins: `anti_lookup` is expected to yield
                                // the rows of `get_left` with no counterpart in `get_join`; they
                                // are added back with NULLs for the `ra` right columns.
744                                if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
745                                    kind
746                                {
747                                    let left_outer = get_left.clone().anti_lookup::<PlanError>(
748                                        id_gen,
749                                        get_join.clone(),
750                                        rt.into_iter()
751                                            .map(|typ| (Datum::Null, typ.scalar_type))
752                                            .collect(),
753                                    )?;
754                                    result = result.union(left_outer);
755                                }
                                // Right (and full) outer joins: the mirror image. `get_join` is
                                // reordered to [outer | right | left] to line up with `get_right`,
                                // NULLs are supplied for the `la` left columns, and the result is
                                // reordered back to [outer | left | right].
756                                if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
757                                    let right_outer = get_right
758                                        .clone()
759                                        .anti_lookup::<PlanError>(
760                                            id_gen,
761                                            get_join
762                                                // need to swap left and right to make the anti_lookup work
763                                                .project(
764                                                    (0..oa)
765                                                        .chain((oa + la)..(oa + la + ra))
766                                                        .chain((oa)..(oa + la))
767                                                        .collect(),
768                                                ),
769                                            lt.into_iter()
770                                                .map(|typ| (Datum::Null, typ.scalar_type))
771                                                .collect(),
772                                        )?
773                                        // swap left and right back again
774                                        .project(
775                                            (0..oa)
776                                                .chain((oa + ra)..(oa + ra + la))
777                                                .chain((oa)..(oa + ra))
778                                                .collect(),
779                                        );
780                                    result = result.union(right_outer);
781                                }
782                                Ok::<MirRelationExpr, PlanError>(result)
783                            })
784                        })
785                    })?
786                }
787                Union { base, inputs } => {
788                    // Union is uncomplicated.
789                    SR::Union {
790                        base: Box::new(base.applied_to(
791                            id_gen,
792                            get_outer.clone(),
793                            col_map,
794                            cte_map,
795                            context,
796                        )?),
797                        inputs: inputs
798                            .into_iter()
799                            .map(|input| {
800                                input.applied_to(
801                                    id_gen,
802                                    get_outer.clone(),
803                                    col_map,
804                                    cte_map,
805                                    context,
806                                )
807                            })
808                            .collect::<Result<Vec<_>, _>>()?,
809                    }
810                }
811                Reduce {
812                    input,
813                    group_key,
814                    aggregates,
815                    expected_group_size,
816                } => {
817                    // Reduce may contain expressions with correlated subqueries.
818                    // In addition, here an empty reduction key signifies that we need to supply default values
819                    // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
820                    let mut input =
821                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
822                    let applied_group_key = (0..get_outer.arity())
823                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
824                        .collect();
825                    let applied_aggregates = aggregates
826                        .into_iter()
827                        .map(|aggregate| {
828                            aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
829                        })
830                        .collect::<Result<Vec<_>, _>>()?;
831                    let input_type = input.typ();
832                    let default = applied_aggregates
833                        .iter()
834                        .map(|agg| {
835                            (
836                                agg.func.default(),
837                                agg.typ(&input_type.column_types).scalar_type,
838                            )
839                        })
840                        .collect();
841                    // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
842                    let mut reduced =
843                        input.reduce(applied_group_key, applied_aggregates, expected_group_size);
844
845                    // Introduce default values in the case the group key is empty.
846                    if group_key.is_empty() {
847                        reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
848                    }
849                    reduced
850                }
851                Distinct { input } => {
852                    // Distinct is uncomplicated.
853                    input
854                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
855                        .distinct()
856                }
857                TopK {
858                    input,
859                    group_key,
860                    order_key,
861                    limit,
862                    offset,
863                    expected_group_size,
864                } => {
865                    // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
866                    let mut input =
867                        input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
868                    let mut applied_group_key: Vec<_> = (0..get_outer.arity())
869                        .chain(group_key.iter().map(|i| get_outer.arity() + i))
870                        .collect();
871                    let applied_order_key = order_key
872                        .iter()
873                        .map(|column_order| ColumnOrder {
874                            column: column_order.column + get_outer.arity(),
875                            desc: column_order.desc,
876                            nulls_last: column_order.nulls_last,
877                        })
878                        .collect();
879
880                    let old_arity = input.arity();
881
882                    // Lower `limit`, which may introduce new columns if is a correlated subquery.
883                    let mut limit_mir = None;
884                    if let Some(limit) = limit {
885                        limit_mir = Some(
886                            limit
887                                .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
888                        );
889                    }
890
891                    let new_arity = input.arity();
892                    // Extend the key to contain any new columns.
893                    applied_group_key.extend(old_arity..new_arity);
894
895                    let mut result = input.top_k(
896                        applied_group_key,
897                        applied_order_key,
898                        limit_mir,
899                        offset,
900                        expected_group_size,
901                    );
902
903                    // If new columns were added for `limit` we must remove them.
904                    if old_arity != new_arity {
905                        result = result.project((0..old_arity).collect());
906                    }
907
908                    result
909                }
910                Negate { input } => {
911                    // Negate is uncomplicated.
912                    input
913                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
914                        .negate()
915                }
916                Threshold { input } => {
917                    // Threshold is uncomplicated.
918                    input
919                        .applied_to(id_gen, get_outer, col_map, cte_map, context)?
920                        .threshold()
921                }
922            })
923        })
924    }
925}
926
927impl HirScalarExpr {
928    /// Rewrite `self` into a `MirScalarExpr` which can be applied to the modified `inner`.
929    ///
930    /// This method is responsible for decorrelating subqueries in `self` by introducing further columns
931    /// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
932    /// The most complicated logic is for the scalar expressions that involve subqueries, each of which is
933    /// documented in more detail closer to its logic.
934    ///
935    /// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
936    /// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
937    /// each of these references can be found.
938    fn applied_to(
939        self,
940        id_gen: &mut mz_ore::id_gen::IdGen,
941        col_map: &ColumnMap,
942        cte_map: &mut CteMap,
943        inner: &mut MirRelationExpr,
944        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
945        context: &Context,
946    ) -> Result<MirScalarExpr, PlanError> {
947        maybe_grow(|| {
948            use MirScalarExpr as SS;
949
950            use HirScalarExpr::*;
951
952            if let Some(subquery_map) = subquery_map {
953                if let Some(col) = subquery_map.get(&self) {
954                    return Ok(SS::Column(*col));
955                }
956            }
957
958            Ok::<MirScalarExpr, PlanError>(match self {
959                Column(col_ref) => SS::Column(col_map.get(&col_ref)),
960                Literal(row, typ) => SS::Literal(Ok(row), typ),
961                Parameter(_) => panic!("cannot decorrelate expression with unbound parameters"),
962                CallUnmaterializable(func) => SS::CallUnmaterializable(func),
963                CallUnary { func, expr } => SS::CallUnary {
964                    func,
965                    expr: Box::new(expr.applied_to(
966                        id_gen,
967                        col_map,
968                        cte_map,
969                        inner,
970                        subquery_map,
971                        context,
972                    )?),
973                },
974                CallBinary { func, expr1, expr2 } => SS::CallBinary {
975                    func,
976                    expr1: Box::new(expr1.applied_to(
977                        id_gen,
978                        col_map,
979                        cte_map,
980                        inner,
981                        subquery_map,
982                        context,
983                    )?),
984                    expr2: Box::new(expr2.applied_to(
985                        id_gen,
986                        col_map,
987                        cte_map,
988                        inner,
989                        subquery_map,
990                        context,
991                    )?),
992                },
993                CallVariadic { func, exprs } => SS::CallVariadic {
994                    func,
995                    exprs: exprs
996                        .into_iter()
997                        .map(|expr| {
998                            expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
999                        })
1000                        .collect::<Result<Vec<_>, _>>()?,
1001                },
1002                If { cond, then, els } => {
1003                    // The `If` case is complicated by the fact that we do not want to
1004                    // apply the `then` or `else` logic to tuples that respectively do
1005                    // not or do pass the `cond` test. Our strategy is to independently
1006                    // decorrelate the `then` and `else` logic, and apply each to tuples
1007                    // that respectively pass and do not pass the `cond` logic (which is
1008                    // executed, and so decorrelated, for all tuples).
1009                    //
1010                    // Informally, we turn the `if` statement into:
1011                    //
1012                    //   let then_case = inner.filter(cond).map(then);
1013                    //   let else_case = inner.filter(!cond).map(else);
1014                    //   return then_case.concat(else_case);
1015                    //
1016                    // We only require this if either expression would result in any
1017                    // computation beyond the expr itself, which we will interpret as
1018                    // "introduces additional columns". In the absence of correlation,
1019                    // we should just retain a `ScalarExpr::If` expression; the inverse
1020                    // transformation as above is complicated to recover after the fact,
1021                    // and we would benefit from not introducing the complexity.
1022
1023                    let inner_arity = inner.arity();
1024                    let cond_expr =
1025                        cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1026
1027                    // Defensive copies, in case we mangle these in decorrelation.
1028                    let inner_clone = inner.clone();
1029                    let then_clone = then.clone();
1030                    let else_clone = els.clone();
1031
1032                    let cond_arity = inner.arity();
1033                    let then_expr =
1034                        then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1035                    let else_expr =
1036                        els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
1037
1038                    if cond_arity == inner.arity() {
1039                        // If no additional columns were added, we simply return the
1040                        // `If` variant with the updated expressions.
1041                        SS::If {
1042                            cond: Box::new(cond_expr),
1043                            then: Box::new(then_expr),
1044                            els: Box::new(else_expr),
1045                        }
1046                    } else {
1047                        // If columns were added, we need a more careful approach, as
1048                        // described above. First, we need to de-correlate each of
1049                        // the two expressions independently, and apply their cases
1050                        // as `MirRelationExpr::Map` operations.
1051
1052                        *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
1053                            // Restrict to records satisfying `cond_expr` and apply `then` as a map.
1054                            let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
1055                            let then_expr = then_clone.applied_to(
1056                                id_gen,
1057                                col_map,
1058                                cte_map,
1059                                &mut then_inner,
1060                                subquery_map,
1061                                context,
1062                            )?;
1063                            let then_arity = then_inner.arity();
1064                            then_inner = then_inner
1065                                .map_one(then_expr)
1066                                .project((0..inner_arity).chain(Some(then_arity)).collect());
1067
1068                            // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
1069                            let mut else_inner = get_inner.filter(vec![SS::CallVariadic {
1070                                func: mz_expr::VariadicFunc::Or,
1071                                exprs: vec![
1072                                    cond_expr
1073                                        .clone()
1074                                        .call_binary(SS::literal_false(), mz_expr::BinaryFunc::Eq),
1075                                    cond_expr.clone().call_is_null(),
1076                                ],
1077                            }]);
1078                            let else_expr = else_clone.applied_to(
1079                                id_gen,
1080                                col_map,
1081                                cte_map,
1082                                &mut else_inner,
1083                                subquery_map,
1084                                context,
1085                            )?;
1086                            let else_arity = else_inner.arity();
1087                            else_inner = else_inner
1088                                .map_one(else_expr)
1089                                .project((0..inner_arity).chain(Some(else_arity)).collect());
1090
1091                            // concatenate the two results.
1092                            Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
1093                        })?;
1094
1095                        SS::Column(inner_arity)
1096                    }
1097                }
1098
1099                // Subqueries!
1100                // These are surprisingly subtle. Things to be careful of:
1101
1102                // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
1103                // * change the row counts of the outer query
1104                // * accidentally compute its own value using the row counts of the outer query
1105                // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
1106                // query and then join the answers back on to the original rows of the outer query.
1107
1108                // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
1109                // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
1110                Exists(expr) => {
1111                    let apply_requires_distinct_outer = true;
1112                    *inner = apply_existential_subquery(
1113                        id_gen,
1114                        inner.take_dangerous(),
1115                        col_map,
1116                        cte_map,
1117                        *expr,
1118                        apply_requires_distinct_outer,
1119                        context,
1120                    )?;
1121                    SS::Column(inner.arity() - 1)
1122                }
1123
1124                Select(expr) => {
1125                    let apply_requires_distinct_outer = true;
1126                    *inner = apply_scalar_subquery(
1127                        id_gen,
1128                        inner.take_dangerous(),
1129                        col_map,
1130                        cte_map,
1131                        *expr,
1132                        apply_requires_distinct_outer,
1133                        context,
1134                    )?;
1135                    SS::Column(inner.arity() - 1)
1136                }
1137                Windowing(expr) => {
1138                    let partition_by = expr.partition_by;
1139                    let order_by = expr.order_by;
1140
1141                    // argument lowering for scalar window functions
1142                    // (We need to specify the & _ in the arguments because of this problem:
1143                    // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
1144                    let scalar_lower_args =
1145                        |_id_gen: &mut _,
1146                         _col_map: &_,
1147                         _cte_map: &mut _,
1148                         _get_inner: &mut _,
1149                         _subquery_map: &Option<&_>,
1150                         order_by_mir: Vec<MirScalarExpr>,
1151                         original_row_record,
1152                         original_row_record_type: ScalarType| {
1153                            let agg_input = MirScalarExpr::CallVariadic {
1154                                func: mz_expr::VariadicFunc::ListCreate {
1155                                    elem_type: original_row_record_type.clone(),
1156                                },
1157                                exprs: vec![original_row_record],
1158                            };
1159                            let mut agg_input = vec![agg_input];
1160                            agg_input.extend(order_by_mir.clone());
1161                            let agg_input = MirScalarExpr::CallVariadic {
1162                                func: mz_expr::VariadicFunc::RecordCreate {
1163                                    field_names: (0..agg_input.len())
1164                                        .map(|_| ColumnName::from("?column?"))
1165                                        .collect_vec(),
1166                                },
1167                                exprs: agg_input,
1168                            };
1169                            let list_type = ScalarType::List {
1170                                element_type: Box::new(original_row_record_type),
1171                                custom_id: None,
1172                            };
1173                            let agg_input_type = ScalarType::Record {
1174                                fields: std::iter::once(&list_type)
1175                                    .map(|t| {
1176                                        (ColumnName::from("?column?"), t.clone().nullable(false))
1177                                    })
1178                                    .collect(),
1179                                custom_id: None,
1180                            }
1181                            .nullable(false);
1182
1183                            Ok((agg_input, agg_input_type))
1184                        };
1185
1186                    // argument lowering for value window functions and aggregate window functions
1187                    let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
1188                        |id_gen: &mut _,
1189                         col_map: &_,
1190                         cte_map: &mut _,
1191                         get_inner: &mut _,
1192                         subquery_map: &Option<&_>,
1193                         order_by_mir: Vec<MirScalarExpr>,
1194                         original_row_record,
1195                         original_row_record_type| {
1196                            // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]
1197
1198                            // Compute the encoded args for all rows
1199                            let mir_encoded_args = hir_encoded_args.applied_to(
1200                                id_gen,
1201                                col_map,
1202                                cte_map,
1203                                get_inner,
1204                                subquery_map,
1205                                context,
1206                            )?;
1207                            let mir_encoded_args_type = mir_encoded_args
1208                                .typ(&get_inner.typ().column_types)
1209                                .scalar_type;
1210
1211                            // Build a new record that has two fields:
1212                            // 1. the original row in a record
1213                            // 2. the encoded args (which can be either a single value, or a record
1214                            //    if the window function has multiple arguments, such as `lag`)
1215                            let fn_input_record_fields: Box<[_]> =
1216                                [original_row_record_type, mir_encoded_args_type]
1217                                    .iter()
1218                                    .map(|t| {
1219                                        (ColumnName::from("?column?"), t.clone().nullable(false))
1220                                    })
1221                                    .collect();
1222                            let fn_input_record = MirScalarExpr::CallVariadic {
1223                                func: mz_expr::VariadicFunc::RecordCreate {
1224                                    field_names: fn_input_record_fields
1225                                        .iter()
1226                                        .map(|(n, _)| n.clone())
1227                                        .collect_vec(),
1228                                },
1229                                exprs: vec![original_row_record, mir_encoded_args],
1230                            };
1231                            let fn_input_record_type = ScalarType::Record {
1232                                fields: fn_input_record_fields,
1233                                custom_id: None,
1234                            }
1235                            .nullable(false);
1236
1237                            // Build a new record with the record above + the ORDER BY exprs
1238                            // This follows the standard encoding of ORDER BY exprs used by aggregate functions
1239                            let mut agg_input = vec![fn_input_record];
1240                            agg_input.extend(order_by_mir.clone());
1241                            let agg_input = MirScalarExpr::CallVariadic {
1242                                func: mz_expr::VariadicFunc::RecordCreate {
1243                                    field_names: (0..agg_input.len())
1244                                        .map(|_| ColumnName::from("?column?"))
1245                                        .collect_vec(),
1246                                },
1247                                exprs: agg_input,
1248                            };
1249
1250                            let agg_input_type = ScalarType::Record {
1251                                fields: [(
1252                                    ColumnName::from("?column?"),
1253                                    fn_input_record_type.nullable(false),
1254                                )]
1255                                .into(),
1256                                custom_id: None,
1257                            }
1258                            .nullable(false);
1259
1260                            Ok((agg_input, agg_input_type))
1261                        }
1262                    };
1263
1264                    match expr.func {
1265                        WindowExprType::Scalar(scalar_window_expr) => {
1266                            let mir_aggr_func = scalar_window_expr.into_expr();
1267                            Self::window_func_applied_to(
1268                                id_gen,
1269                                col_map,
1270                                cte_map,
1271                                inner,
1272                                subquery_map,
1273                                partition_by,
1274                                order_by,
1275                                mir_aggr_func,
1276                                scalar_lower_args,
1277                                context,
1278                            )?
1279                        }
1280                        WindowExprType::Value(value_window_expr) => {
1281                            let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();
1282
1283                            Self::window_func_applied_to(
1284                                id_gen,
1285                                col_map,
1286                                cte_map,
1287                                inner,
1288                                subquery_map,
1289                                partition_by,
1290                                order_by,
1291                                mir_aggr_func,
1292                                value_or_aggr_lower_args(hir_encoded_args),
1293                                context,
1294                            )?
1295                        }
1296                        WindowExprType::Aggregate(aggr_window_expr) => {
1297                            let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();
1298
1299                            Self::window_func_applied_to(
1300                                id_gen,
1301                                col_map,
1302                                cte_map,
1303                                inner,
1304                                subquery_map,
1305                                partition_by,
1306                                order_by,
1307                                mir_aggr_func,
1308                                value_or_aggr_lower_args(hir_encoded_args),
1309                                context,
1310                            )?
1311                        }
1312                    }
1313                }
1314            })
1315        })
1316    }
1317
    /// Lowers a single window function call over `inner`, appending the window function's
    /// result as a new last column of `*inner` and returning a column reference to it.
    ///
    /// `*inner` is replaced by a plan that:
    /// 1. lowers the `order_by` expressions against `inner`;
    /// 2. builds a grouping key from the outer-context columns recorded in `col_map` plus the
    ///    lowered `partition_by` expressions (a partition expression that is not a plain column
    ///    reference is first appended to the input with a `Map`);
    /// 3. packs the input row (its first `input_arity` columns) into a record and hands that
    ///    record, its type, and the lowered ORDER BY expressions to `lower_args`, which returns
    ///    the aggregate's input expression and its column type;
    /// 4. reduces with `mir_aggr_func` on the grouping key, unnests the list the aggregate
    ///    produces, and unpacks each list element into the original row's columns followed by
    ///    the window function result.
    ///
    /// The columns added for partition keys and the grouping key itself are projected away, so
    /// the new `*inner` has the original input columns plus exactly one result column.
    ///
    /// # Errors
    ///
    /// Propagates any error from lowering the ORDER BY or PARTITION BY expressions, or from
    /// `lower_args`.
    fn window_func_applied_to<F>(
        id_gen: &mut mz_ore::id_gen::IdGen,
        col_map: &ColumnMap,
        cte_map: &mut CteMap,
        inner: &mut MirRelationExpr,
        subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
        partition_by: Vec<HirScalarExpr>,
        order_by: Vec<HirScalarExpr>,
        mir_aggr_func: AggregateFunc,
        lower_args: F,
        context: &Context,
    ) -> Result<MirScalarExpr, PlanError>
    where
        F: FnOnce(
            &mut mz_ore::id_gen::IdGen,
            &ColumnMap,
            &mut CteMap,
            &mut MirRelationExpr,
            &Option<&BTreeMap<HirScalarExpr, usize>>,
            Vec<MirScalarExpr>,
            MirScalarExpr,
            ScalarType,
        ) -> Result<(MirScalarExpr, ColumnType), PlanError>,
    {
        // Example MIRs for a window function (specifically, a window aggregation):
        //
        // CREATE TABLE t7(x INT, y INT);
        //
        // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
        //
        // Decorrelated Plan
        // Project (#3)
        //   Map (#2)
        //     Project (#3..=#5)
        //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
        //         FlatMap unnest_list(#1)
        //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
        //             Map ((#0 + #1))
        //               CrossJoin
        //                 Constant
        //                   - ()
        //                 Get materialize.public.t7
        //
        // The same query after optimizations:
        //
        // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
        //
        // Optimized Plan
        // Explained Query:
        //   Project (#2)
        //     Map (record_get[0](#1))
        //       FlatMap unnest_list(#0)
        //         Project (#1)
        //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
        //             ReadStorage materialize.public.t7
        //
        // The `row(row(row(...), ...), ...)` stuff means the following:
        // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
        //   - The <arguments to window function> can be either a single value or itself a
        //     `row` if there are multiple arguments.
        //   - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
        //     ORDER BY columns.
        //   - The <original row> currently always captures the entire original row. This should
        //     improve when we make `ProjectionPushdown` smarter, see
        //     https://github.com/MaterializeInc/database-issues/issues/5090
        //
        // TODO:
        // We should probably introduce some dedicated Datum constructor functions instead of `row`
        // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
        // might even introduce dedicated Datum enum variants, so that the rendering code also
        // becomes more readable (and possibly slightly more performant).

        *inner = inner
            .take_dangerous()
            .let_in(id_gen, |id_gen, mut get_inner| {
                // Lower the ORDER BY expressions first. Lowering takes `&mut get_inner`, so it
                // may extend the relation; any such columns are counted in `input_arity` below.
                let order_by_mir = order_by
                    .into_iter()
                    .map(|o| {
                        o.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut get_inner,
                            subquery_map,
                            context,
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                // Record input arity here so that any group_keys that need to mutate get_inner
                // don't add those columns to the aggregate input.
                let input_type = get_inner.typ();
                let input_arity = input_type.arity();
                // The reduction that computes the window function must be keyed on the columns
                // from the outer context, plus the expressions in the partition key. The current
                // subquery will be 'executed' for every distinct row from the outer context so
                // by putting the outer columns in the grouping key we isolate each re-execution.
                let mut group_key = col_map
                    .inner
                    .iter()
                    .map(|(_, outer_col)| *outer_col)
                    .sorted()
                    .collect_vec();
                for p in partition_by {
                    let key = p.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        context,
                    )?;
                    // Plain column references can be grouped on directly; anything else is
                    // materialized as a new trailing column first.
                    if let MirScalarExpr::Column(c) = key {
                        group_key.push(c);
                    } else {
                        get_inner = get_inner.map_one(key);
                        group_key.push(get_inner.arity() - 1);
                    }
                }

                get_inner.let_in(id_gen, |id_gen, mut get_inner| {
                    // Original columns of the relation
                    let fields: Box<_> = input_type
                        .column_types
                        .iter()
                        .take(input_arity)
                        .map(|t| (ColumnName::from("?column?"), t.clone()))
                        .collect();

                    // Original row made into a record
                    let original_row_record = MirScalarExpr::CallVariadic {
                        func: mz_expr::VariadicFunc::RecordCreate {
                            field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
                        },
                        exprs: (0..input_arity).map(MirScalarExpr::Column).collect_vec(),
                    };
                    let original_row_record_type = ScalarType::Record {
                        fields,
                        custom_id: None,
                    };

                    let (agg_input, agg_input_type) = lower_args(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        order_by_mir,
                        original_row_record,
                        original_row_record_type,
                    )?;

                    let aggregate = mz_expr::AggregateExpr {
                        func: mir_aggr_func,
                        expr: agg_input,
                        distinct: false,
                    };

                    // Actually call reduce with the window function
                    // The output of the aggregation function should be a list of tuples that has
                    // the result in the first position, and the original row in the second position
                    // The Reduce emits the group key columns followed by the aggregate, so the
                    // aggregate's list sits at column `group_key.len()`.
                    let mut reduce = get_inner
                        .reduce(group_key.clone(), vec![aggregate.clone()], None)
                        .flat_map(
                            mz_expr::TableFunc::UnnestList {
                                el_typ: aggregate
                                    .func
                                    .output_type(agg_input_type)
                                    .scalar_type
                                    .unwrap_list_element_type()
                                    .clone(),
                            },
                            vec![MirScalarExpr::Column(group_key.len())],
                        );
                    // The unnested list element is the last column after the FlatMap.
                    let record_col = reduce.arity() - 1;

                    // Unpack the record output by the window function
                    for c in 0..input_arity {
                        reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
                            expr: Box::new(MirScalarExpr::CallUnary {
                                func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
                                expr: Box::new(MirScalarExpr::Column(record_col)),
                            }),
                        });
                    }

                    // Append the column with the result of the window function.
                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
                        expr: Box::new(MirScalarExpr::Column(record_col)),
                    });

                    // Keep only the unpacked original columns and the result column, dropping
                    // the group key, the aggregate list, and the unnested record.
                    let agg_col = record_col + 1 + input_arity;
                    Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
                })
            })?;
        Ok(MirScalarExpr::Column(inner.arity() - 1))
    }
1517
1518    /// Applies the subqueries in the given list of scalar expressions to every distinct
1519    /// value of the given relation and returns a join of the given relation with all
1520    /// the subqueries found, and the mapping of scalar expressions with columns projected
1521    /// by the returned join that will hold their results.
1522    fn lower_subqueries(
1523        exprs: &[Self],
1524        id_gen: &mut mz_ore::id_gen::IdGen,
1525        col_map: &ColumnMap,
1526        cte_map: &mut CteMap,
1527        inner: MirRelationExpr,
1528        context: &Context,
1529    ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1530        let mut subquery_map = BTreeMap::new();
1531        let output = inner.let_in(id_gen, |id_gen, get_inner| {
1532            let mut subqueries = Vec::new();
1533            let distinct_inner = get_inner.clone().distinct();
1534            for expr in exprs.iter() {
1535                expr.visit_pre_post(
1536                    &mut |e| match e {
1537                        // For simplicity, subqueries within a conditional statement will be
1538                        // lowered when lowering the conditional expression.
1539                        HirScalarExpr::If { .. } => Some(vec![]),
1540                        _ => None,
1541                    },
1542                    &mut |e| match e {
1543                        HirScalarExpr::Select(expr) => {
1544                            let apply_requires_distinct_outer = false;
1545                            let subquery = apply_scalar_subquery(
1546                                id_gen,
1547                                distinct_inner.clone(),
1548                                col_map,
1549                                cte_map,
1550                                (**expr).clone(),
1551                                apply_requires_distinct_outer,
1552                                context,
1553                            )
1554                            .unwrap();
1555
1556                            subqueries.push((e.clone(), subquery));
1557                        }
1558                        HirScalarExpr::Exists(expr) => {
1559                            let apply_requires_distinct_outer = false;
1560                            let subquery = apply_existential_subquery(
1561                                id_gen,
1562                                distinct_inner.clone(),
1563                                col_map,
1564                                cte_map,
1565                                (**expr).clone(),
1566                                apply_requires_distinct_outer,
1567                                context,
1568                            )
1569                            .unwrap();
1570                            subqueries.push((e.clone(), subquery));
1571                        }
1572                        _ => {}
1573                    },
1574                )?;
1575            }
1576
1577            if subqueries.is_empty() {
1578                Ok::<MirRelationExpr, PlanError>(get_inner)
1579            } else {
1580                let inner_arity = get_inner.arity();
1581                let mut total_arity = inner_arity;
1582                let mut join_inputs = vec![get_inner];
1583                let mut join_input_arities = vec![inner_arity];
1584                for (expr, subquery) in subqueries.into_iter() {
1585                    // Avoid lowering duplicated subqueries
1586                    if !subquery_map.contains_key(&expr) {
1587                        let subquery_arity = subquery.arity();
1588                        assert_eq!(subquery_arity, inner_arity + 1);
1589                        join_inputs.push(subquery);
1590                        join_input_arities.push(subquery_arity);
1591                        total_arity += subquery_arity;
1592
1593                        // Column with the value of the subquery
1594                        subquery_map.insert(expr, total_arity - 1);
1595                    }
1596                }
1597                // Each subquery projects all the columns of the outer context (distinct_inner)
1598                // plus 1 column, containing the result of the subquery. Those columns must be
1599                // joined with the outer/main relation (get_inner).
1600                let input_mapper =
1601                    mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1602                let equivalences = (0..inner_arity)
1603                    .map(|col| {
1604                        join_inputs
1605                            .iter()
1606                            .enumerate()
1607                            .map(|(input, _)| {
1608                                MirScalarExpr::Column(input_mapper.map_column_to_global(col, input))
1609                            })
1610                            .collect_vec()
1611                    })
1612                    .collect_vec();
1613                Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1614            }
1615        })?;
1616        Ok((output, subquery_map))
1617    }
1618
1619    /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1620    pub fn lower_uncorrelated(self) -> Result<MirScalarExpr, PlanError> {
1621        use MirScalarExpr as SS;
1622
1623        use HirScalarExpr::*;
1624
1625        Ok(match self {
1626            Column(ColumnRef { level: 0, column }) => SS::Column(column),
1627            Literal(datum, typ) => SS::Literal(Ok(datum), typ),
1628            CallUnmaterializable(func) => SS::CallUnmaterializable(func),
1629            CallUnary { func, expr } => SS::CallUnary {
1630                func,
1631                expr: Box::new(expr.lower_uncorrelated()?),
1632            },
1633            CallBinary { func, expr1, expr2 } => SS::CallBinary {
1634                func,
1635                expr1: Box::new(expr1.lower_uncorrelated()?),
1636                expr2: Box::new(expr2.lower_uncorrelated()?),
1637            },
1638            CallVariadic { func, exprs } => SS::CallVariadic {
1639                func,
1640                exprs: exprs
1641                    .into_iter()
1642                    .map(|expr| expr.lower_uncorrelated())
1643                    .collect::<Result<_, _>>()?,
1644            },
1645            If { cond, then, els } => SS::If {
1646                cond: Box::new(cond.lower_uncorrelated()?),
1647                then: Box::new(then.lower_uncorrelated()?),
1648                els: Box::new(els.lower_uncorrelated()?),
1649            },
1650            Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1651                sql_bail!("unexpected ScalarExpr in uncorrelated plan: {:?}", self);
1652            }
1653        })
1654    }
1655}
1656
/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
/// will, in effect, be executed once for every distinct row in `outer`, and the
/// results will be joined with `outer`. Note that columns in `outer` that are
/// not depended upon by `inner` are thrown away before the distinct, so that we
/// don't perform needless computation of `inner`.
///
/// `branch` will inspect the contents of `inner` to determine whether `inner`
/// is not multiplicity sensitive (roughly, contains only maps, filters,
/// projections, and calls to table functions). If it is not multiplicity
/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
/// operations that were not present in `inner`, then set
/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
/// that distinctifies `outer`.
///
/// The caller must supply the `apply` function that applies the rewritten
/// `inner` to `outer`.
///
/// `apply` receives `inner`, the relation to apply it to, and a `ColumnMap` for
/// resolving `inner`'s outer column references against that relation:
/// - On the simple path, that relation is `outer` itself (bound by a `let`), and
///   the result of `apply` is returned without a further join.
/// - Otherwise, it is `outer` made distinct on the key columns that `inner`
///   references (or a single empty row when there are none). The result of
///   `apply` must start with those key columns; it is joined back to `outer` on
///   them, and the right-hand copy of the key is projected away.
///
/// Errors returned by `apply` are propagated.
fn branch<F>(
    id_gen: &mut mz_ore::id_gen::IdGen,
    outer: MirRelationExpr,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: HirRelationExpr,
    apply_requires_distinct_outer: bool,
    context: &Context,
    apply: F,
) -> Result<MirRelationExpr, PlanError>
where
    F: FnOnce(
        &mut mz_ore::id_gen::IdGen,
        HirRelationExpr,
        MirRelationExpr,
        &ColumnMap,
        &mut CteMap,
        &Context,
    ) -> Result<MirRelationExpr, PlanError>,
{
    // TODO: It would be nice to have a version of this code w/o optimizations,
    // at the least for purposes of understanding. It was difficult for one reader
    // to understand the required properties of `outer` and `col_map`.

    // If the inner expression is sufficiently simple, it is safe to apply it
    // *directly* to outer, rather than applying it to the distinctified key
    // (see below).
    //
    // As an example, consider the following two queries:
    //
    //     CREATE TABLE t (a int, b int);
    //     SELECT a, series FROM t, generate_series(1, t.b) series;
    //
    // The "simple" path for the `SELECT` yields
    //
    //     %0 =
    //     | Get t
    //     | FlatMap generate_series(1, #1)
    //
    // while the non-simple path yields:
    //
    //    %0 =
    //    | Get t
    //
    //    %1 =
    //    | Get t
    //    | Distinct group=(#1)
    //    | FlatMap generate_series(1, #0)
    //
    //    %2 =
    //    | Join %0 %1 (= #1 #2)
    //
    // There is a tradeoff here: the simple plan is stateless, but the non-
    // simple plan may do (much) less computation if there are only a few
    // distinct values of `t.b`.
    //
    // We apply a very simple heuristic here and take the simple path if `inner`
    // contains only maps, filters, projections, and calls to table functions.
    // The intuition is that straightforward usage of table functions should
    // take the simple path, while everything else should not. (In theory we
    // think this transformation is valid as long as `inner` does not contain a
    // Reduce, Distinct, or TopK node, but it is not always an optimization in
    // the general case.)
    //
    // TODO(benesch): this should all be handled by a proper optimizer, but
    // detecting the moment of decorrelation in the optimizer right now is too
    // hard.
    let mut is_simple = true;
    #[allow(deprecated)]
    inner.visit(0, &mut |expr, _| match expr {
        HirRelationExpr::Constant { .. }
        | HirRelationExpr::Project { .. }
        | HirRelationExpr::Map { .. }
        | HirRelationExpr::Filter { .. }
        | HirRelationExpr::CallTable { .. } => (),
        _ => is_simple = false,
    });
    if is_simple && !apply_requires_distinct_outer {
        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
        return outer.let_in(id_gen, |id_gen, get_outer| {
            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
        });
    }

    // The key consists of the columns from the outer expression upon which the
    // inner relation depends. We discover these dependencies by walking the
    // inner relation expression and looking for column references whose level
    // escapes inner.
    //
    // At the end of this process, `key` contains the decorrelated position of
    // each outer column, according to the passed-in `col_map`, and
    // `new_col_map` maps each outer column to its new ordinal position in key.
    let mut outer_cols = BTreeSet::new();
    #[allow(deprecated)]
    inner.visit_columns(0, &mut |depth, col| {
        // Test if the column reference escapes the subquery.
        if col.level > depth {
            outer_cols.insert(ColumnRef {
                level: col.level - depth,
                column: col.column,
            });
        }
    });
    // Collect all the outer columns referenced by any CTE referenced by
    // the inner relation.
    #[allow(deprecated)]
    inner.visit(0, &mut |e, _| match e {
        HirRelationExpr::Get {
            id: mz_expr::Id::Local(id),
            ..
        } => {
            if let Some(cte_desc) = cte_map.get(id) {
                let cte_outer_arity = cte_desc.outer_relation.arity();
                outer_cols.extend(
                    col_map
                        .inner
                        .iter()
                        .filter(|(_, position)| **position < cte_outer_arity)
                        .map(|(c, _)| {
                            // `col_map` maps column references to column positions in
                            // `outer`'s projection.
                            // `outer_cols` is meant to contain the external column
                            // references in `inner`.
                            // Since `inner` defines a new scope, any column reference
                            // in `col_map` is one level deeper when seen from within
                            // `inner`, hence the +1.
                            ColumnRef {
                                level: c.level + 1,
                                column: c.column,
                            }
                        }),
                );
            }
        }
        HirRelationExpr::Let { id, .. } => {
            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
            // we would need to remove the old CTE with the same ID temporarily while
            // traversing the definition of the new CTE under the same ID.
            assert!(!cte_map.contains_key(id));
        }
        _ => {}
    });
    // `outer_cols` is a `BTreeSet`, so the key columns are laid out in sorted
    // `ColumnRef` order.
    let mut new_col_map = BTreeMap::new();
    let mut key = vec![];
    for col in outer_cols {
        new_col_map.insert(col, key.len());
        key.push(col_map.get(&ColumnRef {
            // Note: `outer_cols` contains the external column references within `inner`.
            // We must compensate for `inner`'s scope when translating column references
            // as seen within `inner` to column references as seen from `outer`'s context,
            // hence the -1.
            level: col.level - 1,
            column: col.column,
        }));
    }
    let new_col_map = ColumnMap::new(new_col_map);
    outer.let_in(id_gen, |id_gen, get_outer| {
        let keyed_outer = if key.is_empty() {
            // Don't depend on outer at all if the branch is not correlated,
            // which yields vastly better query plans. Note that this is a bit
            // weird in that the branch will be computed even if outer has no
            // rows, whereas if it had been correlated it would not (and *could*
            // not) have been computed if outer had no rows, but the callers of
            // this function don't mind these somewhat-weird semantics.
            MirRelationExpr::constant(vec![vec![]], RelationType::new(vec![]))
        } else {
            get_outer.clone().distinct_by(key.clone())
        };
        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
            let oa = get_outer.arity();
            let branch = apply(
                id_gen,
                inner,
                get_keyed_outer,
                &new_col_map,
                cte_map,
                context,
            )?;
            let ba = branch.arity();
            // Join each `outer` row with the `branch` rows whose leading key
            // columns match that row's key columns.
            let joined = MirRelationExpr::join(
                vec![get_outer.clone(), branch],
                key.iter()
                    .enumerate()
                    .map(|(i, &k)| vec![(0, k), (1, i)])
                    .collect(),
            )
            // throw away the right-hand copy of the key we just joined on
            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
            Ok(joined)
        })
    })
}
1867
1868fn apply_scalar_subquery(
1869    id_gen: &mut mz_ore::id_gen::IdGen,
1870    outer: MirRelationExpr,
1871    col_map: &ColumnMap,
1872    cte_map: &mut CteMap,
1873    scalar_subquery: HirRelationExpr,
1874    apply_requires_distinct_outer: bool,
1875    context: &Context,
1876) -> Result<MirRelationExpr, PlanError> {
1877    branch(
1878        id_gen,
1879        outer,
1880        col_map,
1881        cte_map,
1882        scalar_subquery,
1883        apply_requires_distinct_outer,
1884        context,
1885        |id_gen, expr, get_inner, col_map, cte_map, context| {
1886            // compute for every row in get_inner
1887            let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1888            let col_type = select.typ().column_types.into_last();
1889
1890            let inner_arity = get_inner.arity();
1891            // We must determine a count for each `get_inner` prefix,
1892            // and report an error if that count exceeds one.
1893            let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1894                // Count for each `get_inner` prefix.
1895                let counts = get_select.clone().reduce(
1896                    (0..inner_arity).collect::<Vec<_>>(),
1897                    vec![mz_expr::AggregateExpr {
1898                        func: mz_expr::AggregateFunc::Count,
1899                        expr: MirScalarExpr::literal_true(),
1900                        distinct: false,
1901                    }],
1902                    None,
1903                );
1904                // Errors should result from counts > 1.
1905                let errors = counts
1906                    .filter(vec![MirScalarExpr::Column(inner_arity).call_binary(
1907                        MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
1908                        mz_expr::BinaryFunc::Gt,
1909                    )])
1910                    .project((0..inner_arity).collect::<Vec<_>>())
1911                    .map_one(MirScalarExpr::literal(
1912                        Err(mz_expr::EvalError::MultipleRowsFromSubquery),
1913                        col_type.clone().scalar_type,
1914                    ));
1915                // Return `get_select` and any errors added in.
1916                Ok::<_, PlanError>(get_select.union(errors))
1917            })?;
1918            // append Null to anything that didn't return any rows
1919            let default = vec![(Datum::Null, col_type.scalar_type)];
1920            get_inner.lookup(id_gen, guarded, default)
1921        },
1922    )
1923}
1924
1925fn apply_existential_subquery(
1926    id_gen: &mut mz_ore::id_gen::IdGen,
1927    outer: MirRelationExpr,
1928    col_map: &ColumnMap,
1929    cte_map: &mut CteMap,
1930    subquery_expr: HirRelationExpr,
1931    apply_requires_distinct_outer: bool,
1932    context: &Context,
1933) -> Result<MirRelationExpr, PlanError> {
1934    branch(
1935        id_gen,
1936        outer,
1937        col_map,
1938        cte_map,
1939        subquery_expr,
1940        apply_requires_distinct_outer,
1941        context,
1942        |id_gen, expr, get_inner, col_map, cte_map, context| {
1943            let exists = expr
1944                // compute for every row in get_inner
1945                .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
1946                // throw away actual values and just remember whether or not there were __any__ rows
1947                .distinct_by((0..get_inner.arity()).collect())
1948                // Append true to anything that returned any rows.
1949                .map(vec![MirScalarExpr::literal_true()]);
1950
1951            // append False to anything that didn't return any rows
1952            get_inner.lookup(id_gen, exists, vec![(Datum::False, ScalarType::Bool)])
1953        },
1954    )
1955}
1956
1957impl AggregateExpr {
1958    fn applied_to(
1959        self,
1960        id_gen: &mut mz_ore::id_gen::IdGen,
1961        col_map: &ColumnMap,
1962        cte_map: &mut CteMap,
1963        inner: &mut MirRelationExpr,
1964        context: &Context,
1965    ) -> Result<mz_expr::AggregateExpr, PlanError> {
1966        let AggregateExpr {
1967            func,
1968            expr,
1969            distinct,
1970        } = self;
1971
1972        Ok(mz_expr::AggregateExpr {
1973            func: func.into_expr(),
1974            expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
1975            distinct,
1976        })
1977    }
1978}
1979
/// Attempts an efficient outer join, if `on` has equijoin structure.
///
/// Both `left` and `right` are decorrelated inputs.
///
/// The first `oa` columns of both `left` and `right` correspond to an outer
/// context: we should do the outer join independently for each prefix. In the
/// case that `on` contains just some equality tests between columns of `left`
/// and `right` and some local predicates, we can employ a relatively simple
/// plan.
///
/// The `on` predicate is evaluated against the schema
/// `outer × left × right × on_subqueries`, where the last
/// `on_subquery_types.len()` columns correspond to results from subqueries
/// defined in the `on` clause. We treat those as theta-join conditions that
/// prohibit the use of the simple plan attempted here.
///
/// Returns `Ok(None)`, without building a plan, if `on` is not classified as
/// an equijoin by `OnPredicates::is_equijoin`. Otherwise, returns
/// `Ok(Some(expr))`. Here `expr` has the columns of `left` followed by the
/// non-outer columns of `right`. Depending on `kind` (`LeftOuter`,
/// `RightOuter`, or `FullOuter`), unmatched rows of the corresponding side(s)
/// are added, padded with typed nulls on the opposite side.
fn attempt_outer_equijoin(
    left: MirRelationExpr,
    right: MirRelationExpr,
    on: MirScalarExpr,
    on_subquery_types: Vec<ColumnType>,
    kind: JoinKind,
    oa: usize,
    id_gen: &mut mz_ore::id_gen::IdGen,
    context: &Context,
) -> Result<Option<MirRelationExpr>, PlanError> {
    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
    // predicates that reference subqueries as long as these subqueries don't
    // reference `left` and `right` at the same time.
    //
    // TODO(database-issues#6828): This code can be improved as follows:
    //
    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
    // 2. Use the canonicalized `on` predicate in the non-equijoin based
    //    lowering strategy.
    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
    // 4. Pass the classified `OnPredicates` as a parameter.
    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
    //
    // Steps (1 + 2) require further investigation because we might change the
    // error semantics in case the `on` predicate contains a literal error.

    // Arities of the non-outer parts of `left` and `right`, and of the `on`
    // subquery results.
    let l_type = left.typ();
    let r_type = right.typ();
    let la = l_type.column_types.len() - oa;
    let ra = r_type.column_types.len() - oa;
    let sa = on_subquery_types.len();

    // The schema that `on` is evaluated against: `outer × left × right × on_subqueries`.
    // The outer columns of `right` are skipped because they repeat those of `left`.
    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
    output_type.extend(l_type.column_types);
    output_type.extend(r_type.column_types.into_iter().skip(oa));
    output_type.extend(on_subquery_types);

    // Generally healthy to do, but specifically `USING` conditions sometimes
    // put an `AND true` at the end of the `ON` condition.
    //
    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
    // However, in that case it's not clear that we won't see regressions if
    // `on` simplifies to a literal error.
    let mut on = vec![on];
    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);

    // Form the left and right types without the outer attributes. Afterwards,
    // only the `on` subquery types remain in `output_type`.
    output_type.drain(0..oa);
    let lt = output_type.drain(0..la).collect_vec();
    let rt = output_type.drain(0..ra).collect_vec();
    assert!(output_type.len() == sa);

    // Bail out without producing a plan unless `on` has equijoin structure.
    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
    if !on_predicates.is_equijoin(context) {
        return Ok(None);
    }

    // If we've gotten this far, we can do the clever thing.
    // We'll want to use left and right multiple times, so bind them.
    let result = left.let_in(id_gen, |id_gen, get_left| {
        right.let_in(id_gen, |id_gen, get_right| {
            // TODO: we know that we can re-use the arrangements of left and right
            // needed for the inner join with each of the conditional outer joins.
            // It is not clear whether we should hint that, or just let the planner
            // and optimizer run and see what happens.

            // The inner join on the outer columns with `on` applied on top,
            // minus the repeated outer columns of `right`.
            let join = MirRelationExpr::join(
                vec![get_left.clone(), get_right.clone()],
                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // remove those columns from `right` repeating the first `oa` columns.
            .project(
                (0..(oa + la))
                    .chain((oa + la + oa)..(oa + la + oa + ra))
                    .collect(),
            )
            // apply the filter constraints here, to ensure nulls are not matched.
            .filter(on);

            // We'll want to re-use the results of the join multiple times.
            join.let_in(id_gen, |id_gen, get_join| {
                // Start from the inner join; unmatched rows are added below.
                let mut result = get_join.clone();

                // A collection of keys present in both left and right collections.
                // The keys start with the outer columns (see `OnPredicates::eq_lhs`).
                let join_keys = on_predicates.join_keys();
                let both_keys_arity = join_keys.len();
                let both_keys = get_join.restrict(join_keys).distinct();

                // The plan is now to determine the left and right rows matched in the
                // inner join, subtract them from left and right respectively, pad what
                // remains with nulls, and fold them in to `result`.

                both_keys.let_in(id_gen, |_id_gen, get_both| {
                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
                        // Rows in `left` matched in the inner equijoin. This is
                        // a semi-join between `left` and `both_keys`.
                        let left_present = MirRelationExpr::join_scalars(
                            vec![
                                get_left
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.lhs()),
                                get_both.clone(),
                            ],
                            // Equate the left key expressions with the `both_keys`
                            // columns, which follow the `outer × left` columns.
                            itertools::zip_eq(
                                on_predicates.eq_lhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
                            )
                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
                            .collect(),
                        )
                        .project((0..(oa + la)).collect());

                        // Determine the types of nulls to use as filler.
                        let right_fill = rt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, filled with typed nulls.
                        // `left - left_present` are exactly the unmatched `left` rows.
                        result = left_present
                            .negate()
                            .union(get_left.clone())
                            .map(right_fill)
                            .union(result);
                    }

                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                        // Rows in `right` matched in the inner equijoin. This
                        // is a semi-join between `right` and `both_keys`.
                        let right_present = MirRelationExpr::join_scalars(
                            vec![
                                get_right
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.rhs()),
                                get_both,
                            ],
                            // Equate the right key expressions with the `both_keys`
                            // columns, which follow the `outer × right` columns.
                            itertools::zip_eq(
                                on_predicates.eq_rhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
                            )
                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
                            .collect(),
                        )
                        .project((0..(oa + ra)).collect());

                        // Determine the types of nulls to use as filler.
                        let left_fill = lt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, prepended with typed nulls.
                        // `right - right_present` are exactly the unmatched `right` rows.
                        result = right_present
                            .negate()
                            .union(get_right.clone())
                            .map(left_fill)
                            // Permute left fill before right values: the input has
                            // schema `outer × right × left_fill`, and the output has
                            // `outer × left_fill × right`.
                            .project(
                                itertools::chain!(
                                    0..oa,                 // Preserve `outer`.
                                    oa + ra..oa + la + ra, // Then the `la` null-fill columns.
                                    oa..oa + ra            // Then the `ra` right columns.
                                )
                                .collect(),
                            )
                            .union(result)
                    }

                    Ok::<_, PlanError>(result)
                })
            })
        })
    })?;
    Ok(Some(result))
}
2171
2172/// A struct that represents the predicates in the `on` clause in a form
2173/// suitable for efficient planning outer joins with equijoin predicates.
2174struct OnPredicates {
2175    /// A store for classified `ON` predicates.
2176    ///
2177    /// Predicates that reference a single side are adjusted to assume an
2178    /// `outer × <side>` schema.
2179    predicates: Vec<OnPredicate>,
2180    /// Number of outer context columns.
2181    oa: usize,
2182}
2183
2184impl OnPredicates {
2185    const I_OUT: usize = 0; // outer context input position
2186    const I_LHS: usize = 1; // lhs input position
2187    const I_RHS: usize = 2; // rhs input position
2188    const I_SUB: usize = 3; // on subqueries input position
2189
2190    /// Classify the predicates in the `on` clause of an outer join.
2191    ///
2192    /// The other parameters are arities of the input parts:
2193    ///
2194    /// - `oa` is the arity of the `outer` context.
2195    /// - `la` is the arity of the `left` input.
2196    /// - `ra` is the arity of the `right` input.
2197    /// - `sa` is the arity of the `on` subqueries.
2198    ///
2199    /// The constructor assumes that:
2200    ///
2201    /// 1. The `on` parameter will be applied on a result that has the following
2202    ///    schema `outer × left × right × on_subqueries`.
2203    /// 2. The `on` parameter is already adjusted to assume that schema.
2204    /// 3. The `on` parameter is obtained by canonicalizing the original `on:
2205    ///    MirScalarExpr` with `canonicalize_predicates`.
2206    fn new(
2207        oa: usize,
2208        la: usize,
2209        ra: usize,
2210        sa: usize,
2211        on: Vec<MirScalarExpr>,
2212        _context: &Context,
2213    ) -> Self {
2214        use mz_expr::BinaryFunc::Eq;
2215
2216        // Re-bind those locally for more compact pattern matching.
2217        const I_LHS: usize = OnPredicates::I_LHS;
2218        const I_RHS: usize = OnPredicates::I_RHS;
2219
2220        // Self parameters.
2221        let mut predicates = Vec::with_capacity(on.len());
2222
2223        // Helpers for populating `predicates`.
2224        let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
2225        let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
2226        let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
2227            inner_join_mapper
2228                .lookup_inputs(expr)
2229                .filter(|&i| i != Self::I_OUT)
2230                .collect()
2231        };
2232        let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
2233            inner_join_mapper
2234                .lookup_inputs(expr)
2235                .any(|i| i == Self::I_SUB)
2236        };
2237
2238        // Iterate over `on` elements and populate `predicates`.
2239        for mut predicate in on {
2240            if predicate.might_error() {
2241                tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
2242                // Treat predicates that can produce a literal error as Theta.
2243                predicates.push(OnPredicate::Theta(predicate));
2244            } else if has_subquery_refs(&predicate) {
2245                tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
2246                // Treat predicates referencing an `on` subquery as Theta.
2247                predicates.push(OnPredicate::Theta(predicate));
2248            } else if let MirScalarExpr::CallBinary {
2249                func: Eq,
2250                expr1,
2251                expr2,
2252            } = &mut predicate
2253            {
2254                // Obtain the non-outer inputs referenced by each side.
2255                let inputs1 = lookup_inputs(expr1);
2256                let inputs2 = lookup_inputs(expr2);
2257
2258                match (&inputs1[..], &inputs2[..]) {
2259                    // Neither side references an input. This could be a
2260                    // constant expression or an expression that depends only on
2261                    // the outer context.
2262                    ([], []) => {
2263                        predicates.push(OnPredicate::Const(predicate));
2264                    }
2265                    // Both sides reference different inputs.
2266                    ([I_LHS], [I_RHS]) => {
2267                        let lhs = expr1.take();
2268                        let mut rhs = expr2.take();
2269                        rhs.permute(&rhs_permutation);
2270                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2271                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2272                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2273                    }
2274                    // Both sides reference different inputs (swapped).
2275                    ([I_RHS], [I_LHS]) => {
2276                        let lhs = expr2.take();
2277                        let mut rhs = expr1.take();
2278                        rhs.permute(&rhs_permutation);
2279                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
2280                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
2281                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
2282                    }
2283                    // Both sides reference the left input or no input.
2284                    ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
2285                        predicates.push(OnPredicate::Lhs(predicate));
2286                    }
2287                    // Both sides reference the right input or no input.
2288                    ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
2289                        predicate.permute(&rhs_permutation);
2290                        predicates.push(OnPredicate::Rhs(predicate));
2291                    }
2292                    // At least one side references more than one input.
2293                    _ => {
2294                        tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
2295                        predicates.push(OnPredicate::Theta(predicate));
2296                    }
2297                }
2298            } else {
2299                // Obtain the non-outer inputs referenced by this predicate.
2300                let inputs = lookup_inputs(&predicate);
2301
2302                match &inputs[..] {
2303                    // The predicate references no inputs. This could be a
2304                    // constant expression or an expression that depends only on
2305                    // the outer context.
2306                    [] => {
2307                        predicates.push(OnPredicate::Const(predicate));
2308                    }
2309                    // The predicate references only the left input.
2310                    [I_LHS] => {
2311                        predicates.push(OnPredicate::Lhs(predicate));
2312                    }
2313                    // The predicate references only the right input.
2314                    [I_RHS] => {
2315                        predicate.permute(&rhs_permutation);
2316                        predicates.push(OnPredicate::Rhs(predicate));
2317                    }
2318                    // The predicate references both inputs.
2319                    _ => {
2320                        tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
2321                        predicates.push(OnPredicate::Theta(predicate));
2322                    }
2323                }
2324            }
2325        }
2326
2327        Self { predicates, oa }
2328    }
2329
2330    /// Check if the predicates can be lowered with an equijoin-based strategy.
2331    fn is_equijoin(&self, context: &Context) -> bool {
2332        // Count each `OnPredicate` variant in `self.predicates`.
2333        let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
2334            self.predicates.iter().fold(
2335                (0, 0, 0, 0, 0, 0),
2336                |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
2337                    (
2338                        const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
2339                        lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
2340                        rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
2341                        eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
2342                        eq_cols + usize::from(matches!(p, OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column())),
2343                        theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
2344                    )
2345                },
2346            );
2347
2348        let is_equijion = if context.config.enable_new_outer_join_lowering {
2349            // New classifier.
2350            eq_cnt > 0 && theta_cnt == 0
2351        } else {
2352            // Old classifier.
2353            eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
2354        };
2355
2356        // Log an entry only if this is an equijoin according to the new classifier.
2357        if eq_cnt > 0 && theta_cnt == 0 {
2358            tracing::debug!(
2359                const_cnt,
2360                lhs_cnt,
2361                rhs_cnt,
2362                eq_cnt,
2363                eq_cols,
2364                theta_cnt,
2365                "OnPredicates::is_equijoin"
2366            );
2367        }
2368
2369        is_equijion
2370    }
2371
2372    /// Return an [`MirRelationExpr`] list that represents the keys for the
2373    /// equijoin. The list will contain the outer columns as a prefix.
2374    fn join_keys(&self) -> JoinKeys {
2375        // We could return either the `lhs` or the `rhs` of the keys used to
2376        // form the inner join as they are equated by the join condition.
2377        let join_keys = self.eq_lhs().collect::<Vec<_>>();
2378
2379        if join_keys.iter().all(|k| k.is_column()) {
2380            tracing::debug!(case = "outputs", "OnPredicates::join_keys");
2381            JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
2382        } else {
2383            tracing::debug!(case = "scalars", "OnPredicates::join_keys");
2384            JoinKeys::Scalars(join_keys)
2385        }
2386    }
2387
2388    /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
2389    /// conditions in the predicates list.
2390    ///
2391    /// The iterator will start with column references to the outer columns as a
2392    /// prefix.
2393    fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2394        itertools::chain(
2395            (0..self.oa).map(MirScalarExpr::column),
2396            self.predicates.iter().filter_map(|e| match e {
2397                OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
2398                _ => None,
2399            }),
2400        )
2401    }
2402
2403    /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
2404    /// conditions in the predicates list.
2405    ///
2406    /// The iterator will start with column references to the outer columns as a
2407    /// prefix.
2408    fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2409        itertools::chain(
2410            (0..self.oa).map(MirScalarExpr::column),
2411            self.predicates.iter().filter_map(|e| match e {
2412                OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
2413                _ => None,
2414            }),
2415        )
2416    }
2417
2418    /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
2419    /// [`OnPredicate::Const`] conditions in the predicates list.
2420    fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2421        self.predicates.iter().filter_map(|p| match p {
2422            // We treat Const predicates local to both inputs.
2423            OnPredicate::Const(p) => Some(p.clone()),
2424            OnPredicate::Lhs(p) => Some(p.clone()),
2425            OnPredicate::LhsConsequence(p) => Some(p.clone()),
2426            _ => None,
2427        })
2428    }
2429
2430    /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
2431    /// [`OnPredicate::Const`] conditions in the predicates list.
2432    fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
2433        self.predicates.iter().filter_map(|p| match p {
2434            // We treat Const predicates local to both inputs.
2435            OnPredicate::Const(p) => Some(p.clone()),
2436            OnPredicate::Rhs(p) => Some(p.clone()),
2437            OnPredicate::RhsConsequence(p) => Some(p.clone()),
2438            _ => None,
2439        })
2440    }
2441}
2442
2443enum OnPredicate {
2444    // A predicate that is either constant or references only outer columns.
2445    Const(MirScalarExpr),
2446    // A local predicate on the left-hand side of the join, i.e., it references only the left input
2447    // and possibly outer columns.
2448    //
2449    // This is one of the original predicates from the ON clause.
2450    //
2451    // One _must_ apply this predicate.
2452    Lhs(MirScalarExpr),
2453    // A local predicate on the left-hand side of the join, i.e., it references only the left input
2454    // and possibly outer columns.
2455    //
2456    // This is not one of the original predicates from the ON clause, but is just a consequence
2457    // of an original predicate in the ON clause, where the original predicate references both
2458    // inputs, but the consequence references only the left input.
2459    //
2460    // For example, the original predicate `input1.x = input2.a` has the consequence
2461    // `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
2462    // prevents null skew, and also makes more CSE opportunities available when the left input's key
2463    // doesn't have a NOT NULL constraint, saving us an arrangement.
2464    //
2465    // Applying the predicate is optional, because the original predicate will be applied anyway.
2466    LhsConsequence(MirScalarExpr),
2467    // A local predicate on the right-hand side of the join.
2468    //
2469    // This is one of the original predicates from the ON clause.
2470    //
2471    // One _must_ apply this predicate.
2472    Rhs(MirScalarExpr),
2473    // A consequence of an original ON predicate, see above.
2474    RhsConsequence(MirScalarExpr),
2475    // An equality predicate between the two sides.
2476    Eq(MirScalarExpr, MirScalarExpr),
2477    // a non-equality predicate between the two sides.
2478    #[allow(dead_code)]
2479    Theta(MirScalarExpr),
2480}
2481
2482/// A set of join keys referencing an input.
2483///
2484/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
2485/// avoid changes (and thereby possible regressions) in plans that have equijoin
2486/// predicates consisting only of column refs.
2487///
2488/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
2489/// able to get rid of this code, but as it stands `Map` simplification seems
2490/// more cumbersome than `Project` simplification, so do this just to be sure.
2491enum JoinKeys {
2492    // A predicate that is either constant or references only outer columns.
2493    Outputs(Vec<usize>),
2494    // A local predicate on the left-hand side of the join.
2495    Scalars(Vec<MirScalarExpr>),
2496}
2497
2498impl JoinKeys {
2499    fn len(&self) -> usize {
2500        match self {
2501            JoinKeys::Outputs(outputs) => outputs.len(),
2502            JoinKeys::Scalars(scalars) => scalars.len(),
2503        }
2504    }
2505}
2506
2507/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
2508/// code.
2509trait LoweringExt {
2510    /// See [`MirRelationExpr::restrict`].
2511    fn restrict(self, join_keys: JoinKeys) -> Self;
2512}
2513
2514impl LoweringExt for MirRelationExpr {
2515    /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2516    fn restrict(self, join_keys: JoinKeys) -> Self {
2517        let num_keys = join_keys.len();
2518        match join_keys {
2519            JoinKeys::Outputs(outputs) => self.project(outputs),
2520            JoinKeys::Scalars(scalars) => {
2521                let input_arity = self.arity();
2522                let outputs = (input_arity..input_arity + num_keys).collect();
2523                self.map(scalars).project(outputs)
2524            }
2525        }
2526    }
2527}