mz_sql/plan/lowering.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Lowering is the process of transforming a `HirRelationExpr`
//! into a `MirRelationExpr`.
//!
//! The most crucial part of lowering is decorrelation, i.e., rewriting a
//! `HirScalarExpr` that may contain subqueries (e.g. `SELECT` or `EXISTS`)
//! with instances of `MirScalarExpr` that contain none of these.
//!
//! Informally, a subquery should be viewed as a query that is executed in
//! the context of some outer relation, for each row of that relation. The
//! subqueries often contain references to the columns of the outer
//! relation.
//!
//! The transformation we perform maintains an `outer` relation and then
//! traverses the relation expression that may contain references to those
//! outer columns. As subqueries are discovered, the current relation
//! expression is recast as the outer expression until such a point as the
//! scalar expression's evaluation can be determined and appended to each
//! row of the previously outer relation.
//!
//! It is important that the outer columns (the initial columns) act as keys
//! for all nested computation. When counts or other aggregations are
//! performed, they should include not only the indicated keys but also all
//! of the outer columns.
//!
//! The decorrelation transformation is initialized with an empty outer
//! relation, but it seems entirely appropriate to decorrelate queries that
//! contain "holes" from prepared statements, as if the query was a subquery
//! against a relation containing the assignments of values to those holes.

39use std::collections::{BTreeMap, BTreeSet};
40use std::iter::repeat;
41
42use itertools::Itertools;
43use mz_expr::visit::Visit;
44use mz_expr::{AccessStrategy, AggregateFunc, MirRelationExpr, MirScalarExpr, func};
45use mz_ore::collections::CollectionExt;
46use mz_ore::stack::maybe_grow;
47use mz_repr::*;
48
49use crate::optimizer_metrics::OptimizerMetrics;
50use crate::plan::hir::{
51 AggregateExpr, ColumnOrder, ColumnRef, HirRelationExpr, HirScalarExpr, JoinKind, WindowExprType,
52};
53use crate::plan::{PlanError, transform_hir};
54use crate::session::vars::SystemVars;
55
56mod variadic_left;
57
58/// Maps a leveled column reference to a specific column.
59///
60/// Leveled column references are nested, so that larger levels are
61/// found early in a record and level zero is found at the end.
62///
63/// The column map only stores references for levels greater than zero,
64/// and column references at level zero simply start at the first column
65/// after all prior references.
66#[derive(Debug, Clone)]
67struct ColumnMap {
68 inner: BTreeMap<ColumnRef, usize>,
69}
70
71impl ColumnMap {
72 fn empty() -> ColumnMap {
73 Self::new(BTreeMap::new())
74 }
75
76 fn new(inner: BTreeMap<ColumnRef, usize>) -> ColumnMap {
77 ColumnMap { inner }
78 }
79
80 fn get(&self, col_ref: &ColumnRef) -> usize {
81 if col_ref.level == 0 {
82 self.inner.len() + col_ref.column
83 } else {
84 self.inner[col_ref]
85 }
86 }
87
88 fn len(&self) -> usize {
89 self.inner.len()
90 }
91
92 /// Updates references in the `ColumnMap` for use in a nested scope. The
93 /// provided `arity` must specify the arity of the current scope.
94 fn enter_scope(&self, arity: usize) -> ColumnMap {
95 // From the perspective of the nested scope, all existing column
96 // references will be one level greater.
97 let existing = self
98 .inner
99 .clone()
100 .into_iter()
101 .update(|(col, _i)| col.level += 1);
102
103 // All columns in the current scope become explicit entries in the
104 // immediate parent scope.
105 let new = (0..arity).map(|i| {
106 (
107 ColumnRef {
108 level: 1,
109 column: i,
110 },
111 self.len() + i,
112 )
113 });
114
115 ColumnMap::new(existing.chain(new).collect())
116 }
117}
118
119/// Map with the CTEs currently in scope.
120type CteMap = BTreeMap<mz_expr::LocalId, CteDesc>;
121
122/// Information about needed when finding a reference to a CTE in scope.
123#[derive(Clone)]
124struct CteDesc {
125 /// The new ID assigned to the lowered version of the CTE, which may not match
126 /// the ID of the input CTE.
127 new_id: mz_expr::LocalId,
128 /// The relation type of the CTE including the columns from the outer
129 /// context at the beginning.
130 relation_type: SqlRelationType,
131 /// The outer relation the CTE was applied to.
132 outer_relation: MirRelationExpr,
133}
134
/// Feature flags affecting the behavior of lowering.
///
/// The derived `Default` leaves every flag disabled.
#[derive(Debug, Clone, Copy, Default)]
pub struct Config {
    /// Enable outer join lowering implemented in database-issues#6747.
    pub enable_new_outer_join_lowering: bool,
    /// Enable outer join lowering implemented in database-issues#7561.
    pub enable_variadic_left_join_lowering: bool,
    /// NOTE(review): presumably guards table functions that appear inside
    /// subqueries; the flag is not consumed in this part of the file — confirm.
    pub enable_guard_subquery_tablefunc: bool,
    /// When set, `CastVarCharToString` unary calls are dropped during scalar
    /// lowering and their argument is lowered in their place.
    pub enable_cast_elimination: bool,
}
155
156impl From<&SystemVars> for Config {
157 fn from(vars: &SystemVars) -> Self {
158 Self {
159 enable_new_outer_join_lowering: vars.enable_new_outer_join_lowering(),
160 enable_variadic_left_join_lowering: vars.enable_variadic_left_join_lowering(),
161 enable_guard_subquery_tablefunc: vars.enable_guard_subquery_tablefunc(),
162 enable_cast_elimination: vars.enable_cast_elimination(),
163 }
164 }
165}
166
167/// Context passed to the lowering. This is wired to most parts of the lowering.
168pub(crate) struct Context<'a> {
169 /// Feature flags affecting the behavior of lowering.
170 pub config: &'a Config,
171 /// Optional, because some callers don't have an `OptimizerMetrics` handy. When it's None, we
172 /// simply don't write metrics.
173 pub metrics: Option<&'a OptimizerMetrics>,
174}
175
176impl HirRelationExpr {
177 /// Rewrite `self` into a `MirRelationExpr`.
178 /// This requires rewriting all correlated subqueries (nested `HirRelationExpr`s) into flat queries
179 #[mz_ore::instrument(target = "optimizer", level = "trace", name = "hir_to_mir")]
180 pub fn lower<C: Into<Config>>(
181 self,
182 config: C,
183 metrics: Option<&OptimizerMetrics>,
184 ) -> Result<MirRelationExpr, PlanError> {
185 let context = Context {
186 config: &config.into(),
187 metrics,
188 };
189 let result = match self {
190 // We directly rewrite a Constant into the corresponding `MirRelationExpr::Constant`
191 // to ensure that the downstream optimizer can easily bypass most
192 // irrelevant optimizations (e.g. reduce folding) for this expression
193 // without having to re-learn the fact that it is just a constant,
194 // as it would if the constant were wrapped in a Let-Get pair.
195 HirRelationExpr::Constant { rows, typ } => {
196 let rows: Vec<_> = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
197 MirRelationExpr::Constant {
198 rows: Ok(rows),
199 typ,
200 }
201 }
202 mut other => {
203 let mut id_gen = mz_ore::id_gen::IdGen::default();
204 transform_hir::split_subquery_predicates(&mut other)?;
205 transform_hir::try_simplify_quantified_comparisons(&mut other)?;
206 transform_hir::fuse_window_functions(&mut other, &context)?;
207 MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![])).let_in(
208 &mut id_gen,
209 |id_gen, get_outer| {
210 other.applied_to(
211 id_gen,
212 get_outer,
213 &ColumnMap::empty(),
214 &mut CteMap::new(),
215 &context,
216 )
217 },
218 )?
219 }
220 };
221
222 mz_repr::explain::trace_plan(&result);
223
224 Ok(result)
225 }
226
227 /// Return a `MirRelationExpr` which evaluates `self` once for each row of `get_outer`.
228 ///
229 /// For uncorrelated `self`, this should be the cross-product between `get_outer` and `self`.
230 /// When `self` references columns of `get_outer`, much more work needs to occur.
231 ///
232 /// The `col_map` argument contains mappings to some of the columns of `get_outer`, though
233 /// perhaps not all of them. It should be used as the basis of resolving column references,
234 /// but care must be taken when adding new columns that `get_outer.arity()` is where they
235 /// will start, rather than any function of `col_map`.
236 ///
237 /// The `get_outer` expression should be a `Get` with no duplicate rows, describing the distinct
238 /// assignment of values to outer rows.
239 fn applied_to(
240 self,
241 id_gen: &mut mz_ore::id_gen::IdGen,
242 get_outer: MirRelationExpr,
243 col_map: &ColumnMap,
244 cte_map: &mut CteMap,
245 context: &Context,
246 ) -> Result<MirRelationExpr, PlanError> {
247 maybe_grow(|| {
248 use MirRelationExpr as SR;
249
250 use HirRelationExpr::*;
251
252 if let MirRelationExpr::Get { .. } = &get_outer {
253 } else {
254 panic!(
255 "get_outer: expected a MirRelationExpr::Get, found\n{}",
256 get_outer.pretty(),
257 );
258 }
259 assert_eq!(col_map.len(), get_outer.arity());
260 Ok(match self {
261 Constant { rows, typ } => {
262 // Constant expressions are not correlated with `get_outer`, and should be cross-products.
263 get_outer.product(SR::Constant {
264 rows: Ok(rows.into_iter().map(|row| (row, Diff::ONE)).collect()),
265 typ,
266 })
267 }
268 Get { id, typ } => match id {
269 mz_expr::Id::Local(local_id) => {
270 let cte_desc = cte_map.get(&local_id).unwrap();
271 let get_cte = SR::Get {
272 id: mz_expr::Id::Local(cte_desc.new_id.clone()),
273 typ: cte_desc.relation_type.clone(),
274 access_strategy: AccessStrategy::UnknownOrLocal,
275 };
276 if get_outer == cte_desc.outer_relation {
277 // If the CTE was applied to the same exact relation, we can safely
278 // return a `Get` relation.
279 get_cte
280 } else {
281 // Otherwise, the new outer relation may contain more columns from some
282 // intermediate scope placed between the definition of the CTE and this
283 // reference of the CTE and/or more operations applied on top of the
284 // outer relation.
285 //
286 // An example of the latter is the following query:
287 //
288 // SELECT *
289 // FROM x,
290 // LATERAL(WITH a(m) as (SELECT max(y.a) FROM y WHERE y.a < x.a)
291 // SELECT (SELECT m FROM a) FROM y) b;
292 //
293 // When the CTE is lowered, the outer relation is `Get x`. But then,
294 // the reference of the CTE is applied to `Distinct(Join(Get x, Get y), x.*)`
295 // which has the same cardinality as `Get x`.
296 //
297 // In any case, `get_outer` is guaranteed to contain the columns of the
298 // outer relation the CTE was applied to at its prefix. Since, we must
299 // return a relation containing `get_outer`'s column at the beginning,
300 // we must build a join between `get_outer` and `get_cte` on their common
301 // columns.
302 let oa = get_outer.arity();
303 let cte_outer_columns = cte_desc.relation_type.arity() - typ.arity();
304 let equivalences = (0..cte_outer_columns)
305 .map(|pos| {
306 vec![
307 MirScalarExpr::column(pos),
308 MirScalarExpr::column(pos + oa),
309 ]
310 })
311 .collect();
312
313 // Project out the second copy of the common between `get_outer` and
314 // `cte_desc.outer_relation`.
315 let projection = (0..oa)
316 .chain(oa + cte_outer_columns..oa + cte_outer_columns + typ.arity())
317 .collect_vec();
318 SR::join_scalars(vec![get_outer, get_cte], equivalences)
319 .project(projection)
320 }
321 }
322 _ => {
323 // Get statements are only to external sources, and are not correlated with `get_outer`.
324 get_outer.product(SR::Get {
325 id,
326 typ,
327 access_strategy: AccessStrategy::UnknownOrLocal,
328 })
329 }
330 },
331 Let {
332 name: _,
333 id,
334 value,
335 body,
336 } => {
337 let value =
338 value.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
339 value.let_in(id_gen, |id_gen, get_value| {
340 let (new_id, typ) = if let MirRelationExpr::Get {
341 id: mz_expr::Id::Local(id),
342 typ,
343 ..
344 } = get_value
345 {
346 (id, typ)
347 } else {
348 panic!(
349 "get_value: expected a MirRelationExpr::Get with local Id, found\n{}",
350 get_value.pretty(),
351 );
352 };
353 // Add the information about the CTE to the map and remove it when
354 // it goes out of scope.
355 let old_value = cte_map.insert(
356 id.clone(),
357 CteDesc {
358 new_id,
359 relation_type: typ,
360 outer_relation: get_outer.clone(),
361 },
362 );
363 let body = body.applied_to(id_gen, get_outer, col_map, cte_map, context);
364 if let Some(old_value) = old_value {
365 cte_map.insert(id, old_value);
366 } else {
367 cte_map.remove(&id);
368 }
369 body
370 })?
371 }
372 LetRec {
373 limit,
374 bindings,
375 body,
376 } => {
377 let num_bindings = bindings.len();
378
379 // We use the outer type with the HIR types to form MIR CTE types.
380 let outer_column_types = get_outer.typ().column_types;
381
382 // Rename and introduce all bindings.
383 let mut shadowed_bindings = Vec::with_capacity(num_bindings);
384 let mut mir_ids = Vec::with_capacity(num_bindings);
385 for (_name, id, _value, typ) in bindings.iter() {
386 let mir_id = mz_expr::LocalId::new(id_gen.allocate_id());
387 mir_ids.push(mir_id);
388 let shadowed = cte_map.insert(
389 id.clone(),
390 CteDesc {
391 new_id: mir_id,
392 relation_type: SqlRelationType::new(
393 outer_column_types
394 .iter()
395 .cloned()
396 .chain(typ.column_types.iter().cloned())
397 .collect::<Vec<_>>(),
398 ),
399 outer_relation: get_outer.clone(),
400 },
401 );
402 shadowed_bindings.push((*id, shadowed));
403 }
404
405 let mut mir_values = Vec::with_capacity(num_bindings);
406 for (_name, _id, value, _typ) in bindings.into_iter() {
407 mir_values.push(value.applied_to(
408 id_gen,
409 get_outer.clone(),
410 col_map,
411 cte_map,
412 context,
413 )?);
414 }
415
416 let mir_body = body.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
417
418 // Remove our bindings and reinstate any shadowed bindings.
419 for (id, shadowed) in shadowed_bindings {
420 if let Some(shadowed) = shadowed {
421 cte_map.insert(id, shadowed);
422 } else {
423 cte_map.remove(&id);
424 }
425 }
426
427 MirRelationExpr::LetRec {
428 ids: mir_ids,
429 values: mir_values,
430 // Copy the limit to each binding.
431 limits: repeat(limit).take(num_bindings).collect(),
432 body: Box::new(mir_body),
433 }
434 }
435 Project { input, outputs } => {
436 // Projections should be applied to the decorrelated `inner`, and to its columns,
437 // which means rebasing `outputs` to start `get_outer.arity()` columns later.
438 let input =
439 input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
440 let outputs = (0..get_outer.arity())
441 .chain(outputs.into_iter().map(|i| get_outer.arity() + i))
442 .collect::<Vec<_>>();
443 input.project(outputs)
444 }
445 Map { input, mut scalars } => {
446 // Scalar expressions may contain correlated subqueries. We must be cautious!
447
448 // We lower scalars in chunks, and must keep track of the
449 // arity of the HIR fragments lowered so far.
450 let mut lowered_arity = input.arity();
451
452 let mut input =
453 input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
454
455 // Lower subqueries in maximally sized batches, such as no subquery in the current
456 // batch depends on columns from the same batch.
457 // Note that subqueries in this projection may reference columns added by this
458 // Map operator, so we need to ensure these columns exist before lowering the
459 // subquery.
460 while !scalars.is_empty() {
461 let end_idx = scalars
462 .iter_mut()
463 .position(|s| {
464 let mut requires_nonexistent_column = false;
465 #[allow(deprecated)]
466 s.visit_columns(0, &mut |depth, col| {
467 if col.level == depth {
468 requires_nonexistent_column |= col.column >= lowered_arity
469 }
470 });
471 requires_nonexistent_column
472 })
473 .unwrap_or(scalars.len());
474 assert!(
475 end_idx > 0,
476 "a Map expression references itself or a later column; lowered_arity: {}, expressions: {:?}",
477 lowered_arity,
478 scalars
479 );
480
481 lowered_arity = lowered_arity + end_idx;
482 let scalars = scalars.drain(0..end_idx).collect_vec();
483
484 let old_arity = input.arity();
485 let (with_subqueries, subquery_map) = HirScalarExpr::lower_subqueries(
486 &scalars, id_gen, col_map, cte_map, input, context,
487 )?;
488 input = with_subqueries;
489
490 // We will proceed sequentially through the scalar expressions, for each transforming
491 // the decorrelated `input` into a relation with potentially more columns capable of
492 // addressing the needs of the scalar expression.
493 // Having done so, we add the scalar value of interest and trim off any other newly
494 // added columns.
495 //
496 // The sequential traversal is present as expressions are allowed to depend on the
497 // values of prior expressions.
498 let mut scalar_columns = Vec::new();
499 for scalar in scalars {
500 let scalar = scalar.applied_to(
501 id_gen,
502 col_map,
503 cte_map,
504 &mut input,
505 &Some(&subquery_map),
506 context,
507 )?;
508 input = input.map_one(scalar);
509 scalar_columns.push(input.arity() - 1);
510 }
511
512 // Discard any new columns added by the lowering of the scalar expressions
513 input = input.project((0..old_arity).chain(scalar_columns).collect());
514 }
515
516 input
517 }
518 CallTable { func, exprs } => {
519 // FlatMap expressions may contain correlated subqueries. Unlike Map they are not
520 // allowed to refer to the results of previous expressions, and we have a simpler
521 // implementation that appends all relevant columns first, then applies the flatmap
522 // operator to the result, then strips off any columns introduce by subqueries.
523
524 let mut input = get_outer;
525 let old_arity = input.arity();
526
527 let exprs = exprs
528 .into_iter()
529 .map(|e| e.applied_to(id_gen, col_map, cte_map, &mut input, &None, context))
530 .collect::<Result<Vec<_>, _>>()?;
531
532 let new_arity = input.arity();
533 let output_arity = func.output_arity();
534 input = input.flat_map(func, exprs);
535 if old_arity != new_arity {
536 // this means we added some columns to handle subqueries, and now we need to get rid of them
537 input = input.project(
538 (0..old_arity)
539 .chain(new_arity..new_arity + output_arity)
540 .collect(),
541 );
542 }
543 input
544 }
545 Filter { input, predicates } => {
546 // Filter expressions may contain correlated subqueries.
547 // We extend `get_outer` with sufficient values to determine the value of the predicate,
548 // then filter the results, then strip off any columns that were added for this purpose.
549 let mut input =
550 input.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
551 for predicate in predicates {
552 let old_arity = input.arity();
553 let predicate = predicate
554 .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?;
555 let new_arity = input.arity();
556 input = input.filter(vec![predicate]);
557 if old_arity != new_arity {
558 // this means we added some columns to handle subqueries, and now we need to get rid of them
559 input = input.project((0..old_arity).collect());
560 }
561 }
562 input
563 }
564 Join {
565 left,
566 right,
567 on,
568 kind,
569 } if right.is_correlated() => {
570 // A correlated join is a join in which the right expression has
571 // access to the columns in the left expression. It turns out
572 // this is *exactly* our branch operator, plus some additional
573 // null handling in the case of left joins. (Right and full
574 // lateral joins are not permitted.)
575 //
576 // As with normal joins, the `on` predicate may be correlated,
577 // and we treat it as a filter that follows the branch.
578
579 assert!(kind.can_be_correlated());
580
581 let left = left.applied_to(id_gen, get_outer, col_map, cte_map, context)?;
582 left.let_in(id_gen, |id_gen, get_left| {
583 let apply_requires_distinct_outer = false;
584 let mut join = branch(
585 id_gen,
586 get_left.clone(),
587 col_map,
588 cte_map,
589 *right,
590 apply_requires_distinct_outer,
591 context,
592 |id_gen, right, get_left, col_map, cte_map, context| {
593 right.applied_to(id_gen, get_left, col_map, cte_map, context)
594 },
595 )?;
596
597 // Plan the `on` predicate.
598 let old_arity = join.arity();
599 let on =
600 on.applied_to(id_gen, col_map, cte_map, &mut join, &None, context)?;
601 join = join.filter(vec![on]);
602 let new_arity = join.arity();
603 if old_arity != new_arity {
604 // This means we added some columns to handle
605 // subqueries, and now we need to get rid of them.
606 join = join.project((0..old_arity).collect());
607 }
608
609 // If a left join, reintroduce any rows from the left that
610 // are missing, with nulls filled in for the right columns.
611 if let JoinKind::LeftOuter { .. } = kind {
612 let default = join
613 .typ()
614 .column_types
615 .into_iter()
616 .skip(get_left.arity())
617 .map(|typ| (Datum::Null, typ.scalar_type))
618 .collect();
619 get_left.lookup(id_gen, join, default)
620 } else {
621 Ok::<_, PlanError>(join)
622 }
623 })?
624 }
625 Join {
626 left,
627 right,
628 on,
629 kind,
630 } => {
631 if context.config.enable_variadic_left_join_lowering {
632 // Attempt to extract a stack of left joins.
633 if let JoinKind::LeftOuter = kind {
634 let mut rights = vec![(&*right, &on)];
635 let mut left_test = &left;
636 while let Join {
637 left,
638 right,
639 on,
640 kind: JoinKind::LeftOuter,
641 } = &**left_test
642 {
643 rights.push((&**right, on));
644 left_test = left;
645 }
646 if rights.len() > 1 {
647 // Defensively clone `cte_map` as it may be mutated.
648 let cte_map_clone = cte_map.clone();
649 if let Ok(Some(magic)) = variadic_left::attempt_left_join_magic(
650 left_test,
651 rights,
652 id_gen,
653 get_outer.clone(),
654 col_map,
655 cte_map,
656 context,
657 ) {
658 return Ok(magic);
659 } else {
660 cte_map.clone_from(&cte_map_clone);
661 }
662 }
663 }
664 }
665
666 // Both join expressions should be decorrelated, and then joined by their
667 // leading columns to form only those pairs corresponding to the same row
668 // of `get_outer`.
669 //
670 // The `on` predicate may contain correlated subqueries, and we treat it
671 // as though it was a filter, with the caveat that we also translate outer
672 // joins in this step. The post-filtration results need to be considered
673 // against the records present in the left and right (decorrelated) inputs,
674 // depending on the type of join.
675 let oa = get_outer.arity();
676 let left =
677 left.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
678 let lt = left.typ().column_types.into_iter().skip(oa).collect_vec();
679 let la = lt.len();
680 left.let_in(id_gen, |id_gen, get_left| {
681 let right_col_map = col_map.enter_scope(0);
682 let right = right.applied_to(
683 id_gen,
684 get_outer.clone(),
685 &right_col_map,
686 cte_map,
687 context,
688 )?;
689 let rt = right.typ().column_types.into_iter().skip(oa).collect_vec();
690 let ra = rt.len();
691 right.let_in(id_gen, |id_gen, get_right| {
692 let mut product = SR::join(
693 vec![get_left.clone(), get_right.clone()],
694 (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
695 )
696 // Project away the repeated copy of get_outer's columns.
697 .project(
698 (0..(oa + la))
699 .chain((oa + la + oa)..(oa + la + oa + ra))
700 .collect(),
701 );
702
703 // Decorrelate and lower the `on` clause.
704 let on = on.applied_to(
705 id_gen,
706 col_map,
707 cte_map,
708 &mut product,
709 &None,
710 context,
711 )?;
712 // Collect the types of all subqueries appearing in
713 // the `on` clause. The subquery results were
714 // appended to `product` in the `on.applied_to(...)`
715 // call above.
716 let on_subquery_types = product
717 .typ()
718 .column_types
719 .drain(oa + la + ra..)
720 .collect_vec();
721 // Remember if `on` had any subqueries.
722 let on_has_subqueries = !on_subquery_types.is_empty();
723
724 // Attempt an efficient equijoin implementation, in which outer joins are
725 // more efficiently rendered than in general. This can return `None` if
726 // such a plan is not possible, for example if `on` does not describe an
727 // equijoin between columns of `left` and `right`.
728 if kind != JoinKind::Inner {
729 if let Some(joined) = attempt_outer_equijoin(
730 get_left.clone(),
731 get_right.clone(),
732 on.clone(),
733 on_subquery_types,
734 kind.clone(),
735 oa,
736 id_gen,
737 context,
738 )? {
739 if let Some(metrics) = context.metrics {
740 metrics.inc_outer_join_lowering("equi");
741 }
742 return Ok(joined);
743 }
744 }
745
746 // Otherwise, perform a more general join.
747 if let Some(metrics) = context.metrics {
748 metrics.inc_outer_join_lowering("general");
749 }
750 let mut join = product.filter(vec![on]);
751 if on_has_subqueries {
752 // This means that `on.applied_to(...)` appended
753 // some columns to handle subqueries, and now we
754 // need to get rid of them.
755 join = join.project((0..oa + la + ra).collect());
756 }
757 join.let_in(id_gen, |id_gen, get_join| {
758 let mut result = get_join.clone();
759 if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } =
760 kind
761 {
762 let left_outer = get_left.clone().anti_lookup::<PlanError>(
763 id_gen,
764 get_join.clone(),
765 rt.into_iter()
766 .map(|typ| (Datum::Null, typ.scalar_type))
767 .collect(),
768 )?;
769 result = result.union(left_outer);
770 }
771 if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
772 let right_outer = get_right
773 .clone()
774 .anti_lookup::<PlanError>(
775 id_gen,
776 get_join
777 // need to swap left and right to make the anti_lookup work
778 .project(
779 (0..oa)
780 .chain((oa + la)..(oa + la + ra))
781 .chain((oa)..(oa + la))
782 .collect(),
783 ),
784 lt.into_iter()
785 .map(|typ| (Datum::Null, typ.scalar_type))
786 .collect(),
787 )?
788 // swap left and right back again
789 .project(
790 (0..oa)
791 .chain((oa + ra)..(oa + ra + la))
792 .chain((oa)..(oa + ra))
793 .collect(),
794 );
795 result = result.union(right_outer);
796 }
797 Ok::<MirRelationExpr, PlanError>(result)
798 })
799 })
800 })?
801 }
802 Union { base, inputs } => {
803 // Union is uncomplicated.
804 SR::Union {
805 base: Box::new(base.applied_to(
806 id_gen,
807 get_outer.clone(),
808 col_map,
809 cte_map,
810 context,
811 )?),
812 inputs: inputs
813 .into_iter()
814 .map(|input| {
815 input.applied_to(
816 id_gen,
817 get_outer.clone(),
818 col_map,
819 cte_map,
820 context,
821 )
822 })
823 .collect::<Result<Vec<_>, _>>()?,
824 }
825 }
826 Reduce {
827 input,
828 group_key,
829 aggregates,
830 expected_group_size,
831 } => {
832 // Reduce may contain expressions with correlated subqueries.
833 // In addition, here an empty reduction key signifies that we need to supply default values
834 // in the case that there are no results (as in a SQL aggregation without an explicit GROUP BY).
835 let mut input =
836 input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
837 let applied_group_key = (0..get_outer.arity())
838 .chain(group_key.iter().map(|i| get_outer.arity() + i))
839 .collect();
840 let applied_aggregates = aggregates
841 .into_iter()
842 .map(|aggregate| {
843 aggregate.applied_to(id_gen, col_map, cte_map, &mut input, context)
844 })
845 .collect::<Result<Vec<_>, _>>()?;
846 let input_type = input.typ();
847 let default = applied_aggregates
848 .iter()
849 .map(|agg| {
850 (
851 agg.func.default(),
852 agg.typ(&input_type.column_types).scalar_type,
853 )
854 })
855 .collect();
856 // NOTE we don't need to remove any extra columns from aggregate.applied_to above because the reduce will do that anyway
857 let mut reduced =
858 input.reduce(applied_group_key, applied_aggregates, expected_group_size);
859
860 // Introduce default values in the case the group key is empty.
861 if group_key.is_empty() {
862 reduced = get_outer.lookup::<PlanError>(id_gen, reduced, default)?;
863 }
864 reduced
865 }
866 Distinct { input } => {
867 // Distinct is uncomplicated.
868 input
869 .applied_to(id_gen, get_outer, col_map, cte_map, context)?
870 .distinct()
871 }
872 TopK {
873 input,
874 group_key,
875 order_key,
876 limit,
877 offset,
878 expected_group_size,
879 } => {
880 // TopK is uncomplicated, except that we must group by the columns of `get_outer` as well.
881 let mut input =
882 input.applied_to(id_gen, get_outer.clone(), col_map, cte_map, context)?;
883 let mut applied_group_key: Vec<_> = (0..get_outer.arity())
884 .chain(group_key.iter().map(|i| get_outer.arity() + i))
885 .collect();
886 let applied_order_key = order_key
887 .iter()
888 .map(|column_order| ColumnOrder {
889 column: column_order.column + get_outer.arity(),
890 desc: column_order.desc,
891 nulls_last: column_order.nulls_last,
892 })
893 .collect();
894
895 let old_arity = input.arity();
896
897 // Lower `limit`, which may introduce new columns if is a correlated subquery.
898 let mut limit_mir = None;
899 if let Some(limit) = limit {
900 limit_mir = Some(
901 limit
902 .applied_to(id_gen, col_map, cte_map, &mut input, &None, context)?,
903 );
904 }
905
906 let new_arity = input.arity();
907 // Extend the key to contain any new columns.
908 applied_group_key.extend(old_arity..new_arity);
909
910 let offset = offset
911 .try_into_literal_int64()
912 .expect("Should be a Literal by this time")
913 .try_into()
914 .expect("Should have checked non-negativity of OFFSET clause already");
915 let mut result = input.top_k(
916 applied_group_key,
917 applied_order_key,
918 limit_mir,
919 offset,
920 expected_group_size,
921 );
922
923 // If new columns were added for `limit` we must remove them.
924 if old_arity != new_arity {
925 result = result.project((0..old_arity).collect());
926 }
927
928 result
929 }
930 Negate { input } => {
931 // Negate is uncomplicated.
932 input
933 .applied_to(id_gen, get_outer, col_map, cte_map, context)?
934 .negate()
935 }
936 Threshold { input } => {
937 // Threshold is uncomplicated.
938 input
939 .applied_to(id_gen, get_outer, col_map, cte_map, context)?
940 .threshold()
941 }
942 })
943 })
944 }
945}
946
947impl HirScalarExpr {
/// Rewrite `self` into a `mz_expr::ScalarExpr` which can be applied to the modified `inner`.
///
/// This method is responsible for decorrelating subqueries in `self` by introducing further columns
/// to `inner`, and rewriting `self` to refer to its physical columns (specified by `usize` positions).
/// The most complicated logic is for the scalar expressions that involve subqueries, each of which is
/// documented in more detail closer to its logic.
///
/// This process presumes that `inner` is the result of decorrelation, meaning its first several columns
/// may be inherited from outer relations. The `col_map` column map should provide specific offsets where
/// each of these references can be found.
///
/// If `subquery_map` is `Some` and contains an entry for `self`, nothing is lowered: the
/// result is simply a reference to the column recorded in the map.
///
/// # Errors
///
/// Propagates any `PlanError` produced while lowering nested expressions or subqueries.
///
/// # Panics
///
/// Panics if a `Parameter` expression is reached during lowering.
fn applied_to(
    self,
    id_gen: &mut mz_ore::id_gen::IdGen,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: &mut MirRelationExpr,
    subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
    context: &Context,
) -> Result<MirScalarExpr, PlanError> {
    // NOTE(review): `maybe_grow` presumably guards the recursion below against stack
    // exhaustion on deeply nested expressions -- confirm against its definition.
    maybe_grow(|| {
        use MirScalarExpr as SS;

        use HirScalarExpr::*;

        // An expression that already has a column assigned in `subquery_map` is replaced
        // by a reference to that column instead of being lowered again.
        if let Some(subquery_map) = subquery_map {
            if let Some(col) = subquery_map.get(&self) {
                return Ok(SS::column(*col));
            }
        }

        Ok::<MirScalarExpr, PlanError>(match self {
            // Column references are translated to physical positions through `col_map`.
            Column(col_ref, name) => SS::Column(col_map.get(&col_ref), name),
            Literal(row, typ, _name) => SS::Literal(Ok(row), typ),
            Parameter(_, _name) => {
                panic!("cannot decorrelate expression with unbound parameters")
            }
            CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
            // With cast elimination enabled, the varchar-to-string cast is dropped and
            // only its operand is lowered. This arm must precede the general `CallUnary`.
            CallUnary {
                func: func::UnaryFunc::CastVarCharToString(_),
                expr,
                name: _,
            } if context.config.enable_cast_elimination => {
                expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?
            }
            CallUnary {
                func,
                expr,
                name: _,
            } => SS::CallUnary {
                func,
                expr: Box::new(expr.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    inner,
                    subquery_map,
                    context,
                )?),
            },
            CallBinary {
                func,
                expr1,
                expr2,
                name: _,
            } => SS::CallBinary {
                func,
                expr1: Box::new(expr1.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    inner,
                    subquery_map,
                    context,
                )?),
                expr2: Box::new(expr2.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    inner,
                    subquery_map,
                    context,
                )?),
            },
            CallVariadic {
                func,
                exprs,
                name: _,
            } => SS::CallVariadic {
                func,
                // Arguments are lowered left to right; the first error aborts the lowering.
                exprs: exprs
                    .into_iter()
                    .map(|expr| {
                        expr.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)
                    })
                    .collect::<Result<Vec<_>, _>>()?,
            },
            If {
                cond,
                then,
                els,
                name,
            } => {
                // The `If` case is complicated by the fact that we do not want to
                // apply the `then` or `else` logic to tuples that respectively do
                // not or do pass the `cond` test. Our strategy is to independently
                // decorrelate the `then` and `else` logic, and apply each to tuples
                // that respectively pass and do not pass the `cond` logic (which is
                // executed, and so decorrelated, for all tuples).
                //
                // Informally, we turn the `if` statement into:
                //
                //   let then_case = inner.filter(cond).map(then);
                //   let else_case = inner.filter(!cond).map(else);
                //   return then_case.concat(else_case);
                //
                // We only require this if either expression would result in any
                // computation beyond the expr itself, which we will interpret as
                // "introduces additional columns". In the absence of correlation,
                // we should just retain a `ScalarExpr::If` expression; the inverse
                // transformation as above is complicated to recover after the fact,
                // and we would benefit from not introducing the complexity.

                // Arity before `cond` is lowered. In the complex path below, the value
                // of the whole `If` ends up in the column at this position.
                let inner_arity = inner.arity();
                let cond_expr =
                    cond.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;

                // Defensive copies, in case we mangle these in decorrelation.
                let inner_clone = inner.clone();
                let then_clone = then.clone();
                let else_clone = els.clone();

                // Arity after `cond` has been lowered; used to detect whether lowering the
                // branches added any columns to `inner`.
                let cond_arity = inner.arity();
                let then_expr =
                    then.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;
                let else_expr =
                    els.applied_to(id_gen, col_map, cte_map, inner, subquery_map, context)?;

                if cond_arity == inner.arity() {
                    // If no additional columns were added, we simply return the
                    // `If` variant with the updated expressions.
                    SS::If {
                        cond: Box::new(cond_expr),
                        then: Box::new(then_expr),
                        els: Box::new(else_expr),
                    }
                } else {
                    // If columns were added, we need a more careful approach, as
                    // described above. First, we need to de-correlate each of
                    // the two expressions independently, and apply their cases
                    // as `MirRelationExpr::Map` operations.

                    // The mutated `inner` is discarded: we restart from the copy taken
                    // right after `cond` was lowered.
                    *inner = inner_clone.let_in(id_gen, |id_gen, get_inner| {
                        // Restrict to records satisfying `cond_expr` and apply `then` as a map.
                        let mut then_inner = get_inner.clone().filter(vec![cond_expr.clone()]);
                        let then_expr = then_clone.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut then_inner,
                            subquery_map,
                            context,
                        )?;
                        // Keep the first `inner_arity` columns plus the freshly mapped result.
                        let then_arity = then_inner.arity();
                        then_inner = then_inner
                            .map_one(then_expr)
                            .project((0..inner_arity).chain(Some(then_arity)).collect());

                        // Restrict to records not satisfying `cond_expr` and apply `els` as a map.
                        // "Not satisfying" means `cond = false OR cond IS NULL`.
                        let mut else_inner = get_inner.filter(vec![SS::CallVariadic {
                            func: mz_expr::VariadicFunc::Or,
                            exprs: vec![
                                cond_expr.clone().call_binary(SS::literal_false(), func::Eq),
                                cond_expr.clone().call_is_null(),
                            ],
                        }]);
                        let else_expr = else_clone.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            &mut else_inner,
                            subquery_map,
                            context,
                        )?;
                        // Same projection as for the `then` side, so both sides line up.
                        let else_arity = else_inner.arity();
                        else_inner = else_inner
                            .map_one(else_expr)
                            .project((0..inner_arity).chain(Some(else_arity)).collect());

                        // concatenate the two results.
                        Ok::<MirRelationExpr, PlanError>(then_inner.union(else_inner))
                    })?;

                    SS::Column(inner_arity, name)
                }
            }

            // Subqueries!
            // These are surprisingly subtle. Things to be careful of:

            // Anything in the subquery that cares about row counts (Reduce/Distinct/Negate/Threshold) must not:
            // * change the row counts of the outer query
            // * accidentally compute its own value using the row counts of the outer query
            // Use `branch` to calculate the subquery once for each __distinct__ key in the outer
            // query and then join the answers back on to the original rows of the outer query.

            // When the subquery would return 0 rows for some row in the outer query, `subquery.applied_to(get_inner)` will not have any corresponding row.
            // Use `lookup` if you need to add default values for cases when the subquery returns 0 rows.
            Exists(expr, name) => {
                let apply_requires_distinct_outer = true;
                *inner = apply_existential_subquery(
                    id_gen,
                    inner.take_dangerous(),
                    col_map,
                    cte_map,
                    *expr,
                    apply_requires_distinct_outer,
                    context,
                )?;
                // The result of the subquery is read from the last column of the new `inner`.
                SS::Column(inner.arity() - 1, name)
            }

            Select(expr, name) => {
                let apply_requires_distinct_outer = true;
                *inner = apply_scalar_subquery(
                    id_gen,
                    inner.take_dangerous(),
                    col_map,
                    cte_map,
                    *expr,
                    apply_requires_distinct_outer,
                    context,
                )?;
                // The result of the subquery is read from the last column of the new `inner`.
                SS::Column(inner.arity() - 1, name)
            }
            Windowing(expr, _name) => {
                let partition_by = expr.partition_by;
                let order_by = expr.order_by;

                // argument lowering for scalar window functions
                // (We need to specify the & _ in the arguments because of this problem:
                // https://users.rust-lang.org/t/the-implementation-of-fnonce-is-not-general-enough/72141/3 )
                //
                // Builds `row(list[<original row>], <order by values>...)`; the first five
                // arguments are unused because there are no function arguments to lower.
                let scalar_lower_args =
                    |_id_gen: &mut _,
                     _col_map: &_,
                     _cte_map: &mut _,
                     _get_inner: &mut _,
                     _subquery_map: &Option<&_>,
                     order_by_mir: Vec<MirScalarExpr>,
                     original_row_record,
                     original_row_record_type: SqlScalarType| {
                        let agg_input = MirScalarExpr::CallVariadic {
                            func: mz_expr::VariadicFunc::ListCreate {
                                elem_type: original_row_record_type.clone(),
                            },
                            exprs: vec![original_row_record],
                        };
                        let mut agg_input = vec![agg_input];
                        agg_input.extend(order_by_mir.clone());
                        let agg_input = MirScalarExpr::CallVariadic {
                            func: mz_expr::VariadicFunc::RecordCreate {
                                field_names: (0..agg_input.len())
                                    .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
                                    .collect_vec(),
                            },
                            exprs: agg_input,
                        };
                        let list_type = SqlScalarType::List {
                            element_type: Box::new(original_row_record_type),
                            custom_id: None,
                        };
                        // Note: the declared type only lists the first record field (the
                        // list); the ORDER BY fields are not part of it.
                        let agg_input_type = SqlScalarType::Record {
                            fields: std::iter::once(&list_type)
                                .map(|t| {
                                    (
                                        ColumnName::from(UNKNOWN_COLUMN_NAME),
                                        t.clone().nullable(false),
                                    )
                                })
                                .collect(),
                            custom_id: None,
                        }
                        .nullable(false);

                        Ok((agg_input, agg_input_type))
                    };

                // argument lowering for value window functions and aggregate window functions
                //
                // Returns a closure so that the HIR-encoded arguments can be captured per call.
                let value_or_aggr_lower_args = |hir_encoded_args: Box<HirScalarExpr>| {
                    |id_gen: &mut _,
                     col_map: &_,
                     cte_map: &mut _,
                     get_inner: &mut _,
                     subquery_map: &Option<&_>,
                     order_by_mir: Vec<MirScalarExpr>,
                     original_row_record,
                     original_row_record_type| {
                        // Creates [((OriginalRow, EncodedArgs), OrderByExprs...)]

                        // Compute the encoded args for all rows
                        let mir_encoded_args = hir_encoded_args.applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            get_inner,
                            subquery_map,
                            context,
                        )?;
                        let mir_encoded_args_type = mir_encoded_args
                            .typ(&get_inner.typ().column_types)
                            .scalar_type;

                        // Build a new record that has two fields:
                        // 1. the original row in a record
                        // 2. the encoded args (which can be either a single value, or a record
                        //    if the window function has multiple arguments, such as `lag`)
                        let fn_input_record_fields: Box<[_]> =
                            [original_row_record_type, mir_encoded_args_type]
                                .iter()
                                .map(|t| {
                                    (
                                        ColumnName::from(UNKNOWN_COLUMN_NAME),
                                        t.clone().nullable(false),
                                    )
                                })
                                .collect();
                        let fn_input_record = MirScalarExpr::CallVariadic {
                            func: mz_expr::VariadicFunc::RecordCreate {
                                field_names: fn_input_record_fields
                                    .iter()
                                    .map(|(n, _)| n.clone())
                                    .collect_vec(),
                            },
                            exprs: vec![original_row_record, mir_encoded_args],
                        };
                        let fn_input_record_type = SqlScalarType::Record {
                            fields: fn_input_record_fields,
                            custom_id: None,
                        }
                        .nullable(false);

                        // Build a new record with the record above + the ORDER BY exprs
                        // This follows the standard encoding of ORDER BY exprs used by aggregate functions
                        let mut agg_input = vec![fn_input_record];
                        agg_input.extend(order_by_mir.clone());
                        let agg_input = MirScalarExpr::CallVariadic {
                            func: mz_expr::VariadicFunc::RecordCreate {
                                field_names: (0..agg_input.len())
                                    .map(|_| ColumnName::from(UNKNOWN_COLUMN_NAME))
                                    .collect_vec(),
                            },
                            exprs: agg_input,
                        };

                        // Note: as in the scalar case, the declared type only lists the
                        // first record field; the ORDER BY fields are not part of it.
                        let agg_input_type = SqlScalarType::Record {
                            fields: [(
                                ColumnName::from(UNKNOWN_COLUMN_NAME),
                                fn_input_record_type.nullable(false),
                            )]
                            .into(),
                            custom_id: None,
                        }
                        .nullable(false);

                        Ok((agg_input, agg_input_type))
                    }
                };

                // All three kinds share `window_func_applied_to`; they differ only in how the
                // MIR aggregate function is obtained and in how the arguments are lowered.
                match expr.func {
                    WindowExprType::Scalar(scalar_window_expr) => {
                        let mir_aggr_func = scalar_window_expr.into_expr();
                        Self::window_func_applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            inner,
                            subquery_map,
                            partition_by,
                            order_by,
                            mir_aggr_func,
                            scalar_lower_args,
                            context,
                        )?
                    }
                    WindowExprType::Value(value_window_expr) => {
                        let (hir_encoded_args, mir_aggr_func) = value_window_expr.into_expr();

                        Self::window_func_applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            inner,
                            subquery_map,
                            partition_by,
                            order_by,
                            mir_aggr_func,
                            value_or_aggr_lower_args(hir_encoded_args),
                            context,
                        )?
                    }
                    WindowExprType::Aggregate(aggr_window_expr) => {
                        let (hir_encoded_args, mir_aggr_func) = aggr_window_expr.into_expr();

                        Self::window_func_applied_to(
                            id_gen,
                            col_map,
                            cte_map,
                            inner,
                            subquery_map,
                            partition_by,
                            order_by,
                            mir_aggr_func,
                            value_or_aggr_lower_args(hir_encoded_args),
                            context,
                        )?
                    }
                }
            }
        })
    })
}
1368
/// Lowers a window function call.
///
/// Rewrites `inner` so that it consists of its original columns followed by one extra
/// column holding the result of the window function, and returns a reference to that
/// last column.
///
/// - `partition_by` and `order_by` are lowered against `inner` (ORDER BY first).
/// - `mir_aggr_func` is the aggregate function evaluated by the `Reduce`.
/// - `lower_args` is called exactly once to build the input expression of the
///   aggregation together with its type. It receives the lowered ORDER BY expressions
///   as well as a record expression (and its type) that captures the original row.
///
/// # Errors
///
/// Propagates any `PlanError` raised while lowering the ORDER BY / PARTITION BY
/// expressions or by `lower_args`.
fn window_func_applied_to<F>(
    id_gen: &mut mz_ore::id_gen::IdGen,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: &mut MirRelationExpr,
    subquery_map: &Option<&BTreeMap<HirScalarExpr, usize>>,
    partition_by: Vec<HirScalarExpr>,
    order_by: Vec<HirScalarExpr>,
    mir_aggr_func: AggregateFunc,
    lower_args: F,
    context: &Context,
) -> Result<MirScalarExpr, PlanError>
where
    F: FnOnce(
        &mut mz_ore::id_gen::IdGen,
        &ColumnMap,
        &mut CteMap,
        &mut MirRelationExpr,
        &Option<&BTreeMap<HirScalarExpr, usize>>,
        Vec<MirScalarExpr>,
        MirScalarExpr,
        SqlScalarType,
    ) -> Result<(MirScalarExpr, SqlColumnType), PlanError>,
{
    // Example MIRs for a window function (specifically, a window aggregation):
    //
    // CREATE TABLE t7(x INT, y INT);
    //
    // explain decorrelated plan for select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
    //
    // Decorrelated Plan
    // Project (#3)
    //   Map (#2)
    //     Project (#3..=#5)
    //       Map (record_get[0](record_get[1](#2)), record_get[1](record_get[1](#2)), record_get[0](#2))
    //         FlatMap unnest_list(#1)
    //           Reduce group_by=[#2] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
    //             Map ((#0 + #1))
    //               CrossJoin
    //                 Constant
    //                   - ()
    //                 Get materialize.public.t7
    //
    // The same query after optimizations:
    //
    // explain select sum(x*y) over (partition by x+y order by x-y, x/y) from t7;
    //
    // Optimized Plan
    // Explained Query:
    //   Project (#2)
    //     Map (record_get[0](#1))
    //       FlatMap unnest_list(#0)
    //         Project (#1)
    //           Reduce group_by=[(#0 + #1)] aggregates=[window_agg[sum order_by=[#0 asc nulls_last, #1 asc nulls_last]](row(row(row(#0, #1), (#0 * #1)), (#0 - #1), (#0 / #1)))]
    //             ReadStorage materialize.public.t7
    //
    // The `row(row(row(...), ...), ...)` stuff means the following:
    // `row(row(row(<original row>), <arguments to window function>), <order by values>...)`
    //  - The <arguments to window function> can be either a single value or itself a
    //    `row` if there are multiple arguments.
    //  - The <order by values> are _not_ wrapped in a `row`, even if there are more than one
    //    ORDER BY columns.
    //  - The <original row> currently always captures the entire original row. This should
    //    improve when we make `ProjectionPushdown` smarter, see
    //    https://github.com/MaterializeInc/database-issues/issues/5090
    //
    // TODO:
    // We should probably introduce some dedicated Datum constructor functions instead of `row`
    // to make MIR plans and MIR construction/manipulation code more readable. Additionally, we
    // might even introduce dedicated Datum enum variants, so that the rendering code also
    // becomes more readable (and possibly slightly more performant).

    *inner = inner
        .take_dangerous()
        .let_in(id_gen, |id_gen, mut get_inner| {
            // Lower the ORDER BY expressions first; this may add columns to `get_inner`.
            let order_by_mir = order_by
                .into_iter()
                .map(|o| {
                    o.applied_to(
                        id_gen,
                        col_map,
                        cte_map,
                        &mut get_inner,
                        subquery_map,
                        context,
                    )
                })
                .collect::<Result<Vec<_>, _>>()?;

            // Record input arity here so that any group_keys that need to mutate get_inner
            // don't add those columns to the aggregate input.
            let input_type = get_inner.typ();
            let input_arity = input_type.arity();
            // The reduction that computes the window function must be keyed on the columns
            // from the outer context, plus the expressions in the partition key. The current
            // subquery will be 'executed' for every distinct row from the outer context so
            // by putting the outer columns in the grouping key we isolate each re-execution.
            let mut group_key = col_map
                .inner
                .iter()
                .map(|(_, outer_col)| *outer_col)
                .sorted()
                .collect_vec();
            for p in partition_by {
                let key = p.applied_to(
                    id_gen,
                    col_map,
                    cte_map,
                    &mut get_inner,
                    subquery_map,
                    context,
                )?;
                // A plain column reference can be used as a key directly; any other
                // expression is first materialized as an extra column.
                if let MirScalarExpr::Column(c, _name) = key {
                    group_key.push(c);
                } else {
                    get_inner = get_inner.map_one(key);
                    group_key.push(get_inner.arity() - 1);
                }
            }

            get_inner.let_in(id_gen, |id_gen, mut get_inner| {
                // Original columns of the relation
                let fields: Box<_> = input_type
                    .column_types
                    .iter()
                    .take(input_arity)
                    .map(|t| (ColumnName::from(UNKNOWN_COLUMN_NAME), t.clone()))
                    .collect();

                // Original row made into a record
                let original_row_record = MirScalarExpr::CallVariadic {
                    func: mz_expr::VariadicFunc::RecordCreate {
                        field_names: fields.iter().map(|(name, _)| name.clone()).collect_vec(),
                    },
                    exprs: (0..input_arity).map(MirScalarExpr::column).collect_vec(),
                };
                let original_row_record_type = SqlScalarType::Record {
                    fields,
                    custom_id: None,
                };

                // Delegate the construction of the aggregation input to the caller-supplied
                // callback (it differs between scalar and value/aggregate window functions).
                let (agg_input, agg_input_type) = lower_args(
                    id_gen,
                    col_map,
                    cte_map,
                    &mut get_inner,
                    subquery_map,
                    order_by_mir,
                    original_row_record,
                    original_row_record_type,
                )?;

                let aggregate = mz_expr::AggregateExpr {
                    func: mir_aggr_func,
                    expr: agg_input,
                    distinct: false,
                };

                // Actually call reduce with the window function
                // The output of the aggregation function should be a list of tuples that has
                // the result in the first position, and the original row in the second position
                //
                // After the `Reduce` the columns are the group key followed by the single
                // aggregate, hence the aggregate is found at index `group_key.len()`.
                let mut reduce = get_inner
                    .reduce(group_key.clone(), vec![aggregate.clone()], None)
                    .flat_map(
                        mz_expr::TableFunc::UnnestList {
                            el_typ: aggregate
                                .func
                                .output_type(agg_input_type)
                                .scalar_type
                                .unwrap_list_element_type()
                                .clone(),
                        },
                        vec![MirScalarExpr::column(group_key.len())],
                    );
                // The unnested `(result, original row)` record is the last column.
                let record_col = reduce.arity() - 1;

                // Unpack the record output by the window function
                for c in 0..input_arity {
                    reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                        func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(c)),
                        expr: Box::new(MirScalarExpr::CallUnary {
                            func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(1)),
                            expr: Box::new(MirScalarExpr::column(record_col)),
                        }),
                    });
                }

                // Append the column with the result of the window function.
                reduce = reduce.take_dangerous().map_one(MirScalarExpr::CallUnary {
                    func: mz_expr::UnaryFunc::RecordGet(mz_expr::func::RecordGet(0)),
                    expr: Box::new(MirScalarExpr::column(record_col)),
                });

                // Keep only the unpacked original columns and the result column; the group
                // key, the raw aggregate and the record column are projected away.
                let agg_col = record_col + 1 + input_arity;
                Ok::<_, PlanError>(reduce.project((record_col + 1..agg_col + 1).collect_vec()))
            })
        })?;
    // The window function result is the last column of the rewritten `inner`.
    Ok(MirScalarExpr::column(inner.arity() - 1))
}
1568
1569 /// Applies the subqueries in the given list of scalar expressions to every distinct
1570 /// value of the given relation and returns a join of the given relation with all
1571 /// the subqueries found, and the mapping of scalar expressions with columns projected
1572 /// by the returned join that will hold their results.
1573 fn lower_subqueries(
1574 exprs: &[Self],
1575 id_gen: &mut mz_ore::id_gen::IdGen,
1576 col_map: &ColumnMap,
1577 cte_map: &mut CteMap,
1578 inner: MirRelationExpr,
1579 context: &Context,
1580 ) -> Result<(MirRelationExpr, BTreeMap<HirScalarExpr, usize>), PlanError> {
1581 let mut subquery_map = BTreeMap::new();
1582 let output = inner.let_in(id_gen, |id_gen, get_inner| {
1583 let mut subqueries = Vec::new();
1584 let distinct_inner = get_inner.clone().distinct();
1585 for expr in exprs.iter() {
1586 expr.visit_pre_post(
1587 &mut |e| match e {
1588 // For simplicity, subqueries within a conditional statement will be
1589 // lowered when lowering the conditional expression.
1590 HirScalarExpr::If { .. } => Some(vec![]),
1591 _ => None,
1592 },
1593 &mut |e| match e {
1594 HirScalarExpr::Select(expr, _name) => {
1595 let apply_requires_distinct_outer = false;
1596 let subquery = apply_scalar_subquery(
1597 id_gen,
1598 distinct_inner.clone(),
1599 col_map,
1600 cte_map,
1601 (**expr).clone(),
1602 apply_requires_distinct_outer,
1603 context,
1604 )
1605 .unwrap();
1606
1607 subqueries.push((e.clone(), subquery));
1608 }
1609 HirScalarExpr::Exists(expr, _name) => {
1610 let apply_requires_distinct_outer = false;
1611 let subquery = apply_existential_subquery(
1612 id_gen,
1613 distinct_inner.clone(),
1614 col_map,
1615 cte_map,
1616 (**expr).clone(),
1617 apply_requires_distinct_outer,
1618 context,
1619 )
1620 .unwrap();
1621 subqueries.push((e.clone(), subquery));
1622 }
1623 _ => {}
1624 },
1625 )?;
1626 }
1627
1628 if subqueries.is_empty() {
1629 Ok::<MirRelationExpr, PlanError>(get_inner)
1630 } else {
1631 let inner_arity = get_inner.arity();
1632 let mut total_arity = inner_arity;
1633 let mut join_inputs = vec![get_inner];
1634 let mut join_input_arities = vec![inner_arity];
1635 for (expr, subquery) in subqueries.into_iter() {
1636 // Avoid lowering duplicated subqueries
1637 if !subquery_map.contains_key(&expr) {
1638 let subquery_arity = subquery.arity();
1639 assert_eq!(subquery_arity, inner_arity + 1);
1640 join_inputs.push(subquery);
1641 join_input_arities.push(subquery_arity);
1642 total_arity += subquery_arity;
1643
1644 // Column with the value of the subquery
1645 subquery_map.insert(expr, total_arity - 1);
1646 }
1647 }
1648 // Each subquery projects all the columns of the outer context (distinct_inner)
1649 // plus 1 column, containing the result of the subquery. Those columns must be
1650 // joined with the outer/main relation (get_inner).
1651 let input_mapper =
1652 mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
1653 let equivalences = (0..inner_arity)
1654 .map(|col| {
1655 join_inputs
1656 .iter()
1657 .enumerate()
1658 .map(|(input, _)| {
1659 MirScalarExpr::column(input_mapper.map_column_to_global(col, input))
1660 })
1661 .collect_vec()
1662 })
1663 .collect_vec();
1664 Ok(MirRelationExpr::join_scalars(join_inputs, equivalences))
1665 }
1666 })?;
1667 Ok((output, subquery_map))
1668 }
1669
1670 /// Rewrites `self` into a `mz_expr::ScalarExpr`.
1671 ///
1672 /// Returns an _internal_ error if the expression contains
1673 /// - a subquery
1674 /// - a column reference to an outer level
1675 /// - a parameter
1676 /// - a window function call
1677 ///
1678 /// Should succeed if [`HirScalarExpr::is_constant`] would return true on `self`.
1679 ///
1680 /// Set `enable_cast_elimination` to remove casts that are noops in MIR.
1681 pub fn lower_uncorrelated<C: Into<Config>>(
1682 self,
1683 config: C,
1684 ) -> Result<MirScalarExpr, PlanError> {
1685 let config = config.into();
1686
1687 use MirScalarExpr as SS;
1688
1689 use HirScalarExpr::*;
1690
1691 Ok(match self {
1692 Column(ColumnRef { level: 0, column }, name) => SS::Column(column, name),
1693 Literal(datum, typ, _name) => SS::Literal(Ok(datum), typ),
1694 CallUnmaterializable(func, _name) => SS::CallUnmaterializable(func),
1695 CallUnary {
1696 func: func::UnaryFunc::CastVarCharToString(_),
1697 expr,
1698 name: _,
1699 } if config.enable_cast_elimination => expr.lower_uncorrelated(config)?,
1700 CallUnary {
1701 func,
1702 expr,
1703 name: _,
1704 } => SS::CallUnary {
1705 func,
1706 expr: Box::new(expr.lower_uncorrelated(config)?),
1707 },
1708 CallBinary {
1709 func,
1710 expr1,
1711 expr2,
1712 name: _,
1713 } => SS::CallBinary {
1714 func,
1715 expr1: Box::new(expr1.lower_uncorrelated(config)?),
1716 expr2: Box::new(expr2.lower_uncorrelated(config)?),
1717 },
1718 CallVariadic {
1719 func,
1720 exprs,
1721 name: _,
1722 } => SS::CallVariadic {
1723 func,
1724 exprs: exprs
1725 .into_iter()
1726 .map(|expr| expr.lower_uncorrelated(config))
1727 .collect::<Result<_, _>>()?,
1728 },
1729 If {
1730 cond,
1731 then,
1732 els,
1733 name: _,
1734 } => SS::If {
1735 cond: Box::new(cond.lower_uncorrelated(config)?),
1736 then: Box::new(then.lower_uncorrelated(config)?),
1737 els: Box::new(els.lower_uncorrelated(config)?),
1738 },
1739 Select { .. } | Exists { .. } | Parameter(..) | Column(..) | Windowing(..) => {
1740 sql_bail!(
1741 "Internal error: unexpected HirScalarExpr in lower_uncorrelated: {:?}",
1742 self
1743 );
1744 }
1745 })
1746 }
1747}
1748
/// Prepare to apply `inner` to `outer`. Note that `inner` is a correlated (SQL)
/// expression, while `outer` is a non-correlated (dataflow) expression. `inner`
/// will, in effect, be executed once for every distinct row in `outer`, and the
/// results will be joined with `outer`. Note that columns in `outer` that are
/// not depended upon by `inner` are thrown away before the distinct, so that we
/// don't perform needless computation of `inner`.
///
/// `branch` will inspect the contents of `inner` to determine whether `inner`
/// is not multiplicity sensitive (roughly, contains only maps, filters,
/// projections, and calls to table functions). If it is not multiplicity
/// sensitive, `branch` will *not* distinctify outer. If this is problematic,
/// e.g. because the `apply` callback itself introduces multiplicity-sensitive
/// operations that were not present in `inner`, then set
/// `apply_requires_distinct_outer` to ensure that `branch` chooses the plan
/// that distinctifies `outer`.
///
/// The caller must supply the `apply` function that applies the rewritten
/// `inner` to `outer`. `apply` is invoked exactly once, with either `outer`
/// itself (simple path) or the distinct key columns of `outer` as the relation
/// to apply `inner` to, together with a column map that matches that relation.
///
/// # Errors
///
/// Propagates any `PlanError` returned by `apply`.
///
/// # Panics
///
/// Panics if `inner` contains a `Let` whose id is already present in `cte_map`.
fn branch<F>(
    id_gen: &mut mz_ore::id_gen::IdGen,
    outer: MirRelationExpr,
    col_map: &ColumnMap,
    cte_map: &mut CteMap,
    inner: HirRelationExpr,
    apply_requires_distinct_outer: bool,
    context: &Context,
    apply: F,
) -> Result<MirRelationExpr, PlanError>
where
    F: FnOnce(
        &mut mz_ore::id_gen::IdGen,
        HirRelationExpr,
        MirRelationExpr,
        &ColumnMap,
        &mut CteMap,
        &Context,
    ) -> Result<MirRelationExpr, PlanError>,
{
    // TODO: It would be nice to have a version of this code w/o optimizations,
    // at the least for purposes of understanding. It was difficult for one reader
    // to understand the required properties of `outer` and `col_map`.

    // If the inner expression is sufficiently simple, it is safe to apply it
    // *directly* to outer, rather than applying it to the distinctified key
    // (see below).
    //
    // As an example, consider the following two queries:
    //
    //     CREATE TABLE t (a int, b int);
    //     SELECT a, series FROM t, generate_series(1, t.b) series;
    //
    // The "simple" path for the `SELECT` yields
    //
    //     %0 =
    //     | Get t
    //     | FlatMap generate_series(1, #1)
    //
    // while the non-simple path yields:
    //
    //     %0 =
    //     | Get t
    //
    //     %1 =
    //     | Get t
    //     | Distinct group=(#1)
    //     | FlatMap generate_series(1, #0)
    //
    //     %2 =
    //     | Join %0 %1 (= #1 #2)
    //
    // There is a tradeoff here: the simple plan is stateless, but the non-
    // simple plan may do (much) less computation if there are only a few
    // distinct values of `t.b`.
    //
    // We apply a very simple heuristic here and take the simple path if `inner`
    // contains only maps, filters, projections, and calls to table functions.
    // The intuition is that straightforward usage of table functions should
    // take the simple path, while everything else should not. (In theory we
    // think this transformation is valid as long as `inner` does not contain a
    // Reduce, Distinct, or TopK node, but it is not always an optimization in
    // the general case.)
    //
    // TODO(benesch): this should all be handled by a proper optimizer, but
    // detecting the moment of decorrelation in the optimizer right now is too
    // hard.
    let mut is_simple = true;
    #[allow(deprecated)]
    inner.visit(0, &mut |expr, _| match expr {
        HirRelationExpr::Constant { .. }
        | HirRelationExpr::Project { .. }
        | HirRelationExpr::Map { .. }
        | HirRelationExpr::Filter { .. }
        | HirRelationExpr::CallTable { .. } => (),
        _ => is_simple = false,
    });
    if is_simple && !apply_requires_distinct_outer {
        // Simple path: apply `inner` directly to `outer`, after entering a new scope
        // in the column map for the columns of `outer` that `col_map` does not cover.
        let new_col_map = col_map.enter_scope(outer.arity() - col_map.len());
        return outer.let_in(id_gen, |id_gen, get_outer| {
            apply(id_gen, inner, get_outer, &new_col_map, cte_map, context)
        });
    }

    // The key consists of the columns from the outer expression upon which the
    // inner relation depends. We discover these dependencies by walking the
    // inner relation expression and looking for column references whose level
    // escapes inner.
    //
    // At the end of this process, `key` contains the decorrelated position of
    // each outer column, according to the passed-in `col_map`, and
    // `new_col_map` maps each outer column to its new ordinal position in key.
    let mut outer_cols = BTreeSet::new();
    #[allow(deprecated)]
    inner.visit_columns(0, &mut |depth, col| {
        // Test if the column reference escapes the subquery.
        if col.level > depth {
            // Normalize the reference so that it is relative to the root of `inner`.
            outer_cols.insert(ColumnRef {
                level: col.level - depth,
                column: col.column,
            });
        }
    });
    // Collect all the outer columns referenced by any CTE referenced by
    // the inner relation.
    #[allow(deprecated)]
    inner.visit(0, &mut |e, _| match e {
        HirRelationExpr::Get {
            id: mz_expr::Id::Local(id),
            ..
        } => {
            if let Some(cte_desc) = cte_map.get(id) {
                // Only columns that already existed in the CTE's outer relation are relevant.
                let cte_outer_arity = cte_desc.outer_relation.arity();
                outer_cols.extend(
                    col_map
                        .inner
                        .iter()
                        .filter(|(_, position)| **position < cte_outer_arity)
                        .map(|(c, _)| {
                            // `col_map` maps column references to column positions in
                            // `outer`'s projection.
                            // `outer_cols` is meant to contain the external column
                            // references in `inner`.
                            // Since `inner` defines a new scope, any column reference
                            // in `col_map` is one level deeper when seen from within
                            // `inner`, hence the +1.
                            ColumnRef {
                                level: c.level + 1,
                                column: c.column,
                            }
                        }),
                );
            }
        }
        HirRelationExpr::Let { id, .. } => {
            // Note: if ID uniqueness is not guaranteed, we can't use `visit` since
            // we would need to remove the old CTE with the same ID temporarily while
            // traversing the definition of the new CTE under the same ID.
            assert!(!cte_map.contains_key(id));
        }
        _ => {}
    });
    // Assign each referenced outer column a position in the key, in the iteration
    // order of `outer_cols`, and remember where the column lives in `outer`.
    let mut new_col_map = BTreeMap::new();
    let mut key = vec![];
    for col in outer_cols {
        new_col_map.insert(col, key.len());
        key.push(col_map.get(&ColumnRef {
            // Note: `outer_cols` contains the external column references within `inner`.
            // We must compensate for `inner`'s scope when translating column references
            // as seen within `inner` to column references as seen from `outer`'s context,
            // hence the -1.
            level: col.level - 1,
            column: col.column,
        }));
    }
    let new_col_map = ColumnMap::new(new_col_map);
    outer.let_in(id_gen, |id_gen, get_outer| {
        let keyed_outer = if key.is_empty() {
            // Don't depend on outer at all if the branch is not correlated,
            // which yields vastly better query plans. Note that this is a bit
            // weird in that the branch will be computed even if outer has no
            // rows, whereas if it had been correlated it would not (and *could*
            // not) have been computed if outer had no rows, but the callers of
            // this function don't mind these somewhat-weird semantics.
            MirRelationExpr::constant(vec![vec![]], SqlRelationType::new(vec![]))
        } else {
            get_outer.clone().distinct_by(key.clone())
        };
        keyed_outer.let_in(id_gen, |id_gen, get_keyed_outer| {
            // `oa`: arity of `outer`; `ba`: arity of the applied branch.
            let oa = get_outer.arity();
            let branch = apply(
                id_gen,
                inner,
                get_keyed_outer,
                &new_col_map,
                cte_map,
                context,
            )?;
            let ba = branch.arity();
            // Join `outer` with the branch: the i-th key column of the branch must equal
            // column `key[i]` of `outer`.
            let joined = MirRelationExpr::join(
                vec![get_outer.clone(), branch],
                key.iter()
                    .enumerate()
                    .map(|(i, &k)| vec![(0, k), (1, i)])
                    .collect(),
            )
            // throw away the right-hand copy of the key we just joined on
            .project((0..oa).chain((oa + key.len())..(oa + ba)).collect());
            Ok(joined)
        })
    })
}
1959
1960fn apply_scalar_subquery(
1961 id_gen: &mut mz_ore::id_gen::IdGen,
1962 outer: MirRelationExpr,
1963 col_map: &ColumnMap,
1964 cte_map: &mut CteMap,
1965 scalar_subquery: HirRelationExpr,
1966 apply_requires_distinct_outer: bool,
1967 context: &Context,
1968) -> Result<MirRelationExpr, PlanError> {
1969 branch(
1970 id_gen,
1971 outer,
1972 col_map,
1973 cte_map,
1974 scalar_subquery,
1975 apply_requires_distinct_outer,
1976 context,
1977 |id_gen, expr, get_inner, col_map, cte_map, context| {
1978 // compute for every row in get_inner
1979 let select = expr.applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?;
1980 let col_type = select.typ().column_types.into_last();
1981
1982 let inner_arity = get_inner.arity();
1983 // We must determine a count for each `get_inner` prefix,
1984 // and report an error if that count exceeds one.
1985 let guarded = select.let_in(id_gen, |_id_gen, get_select| {
1986 // Count for each `get_inner` prefix.
1987 let counts = get_select.clone().reduce(
1988 (0..inner_arity).collect::<Vec<_>>(),
1989 vec![mz_expr::AggregateExpr {
1990 func: mz_expr::AggregateFunc::Count,
1991 expr: MirScalarExpr::literal_true(),
1992 distinct: false,
1993 }],
1994 None,
1995 );
1996
1997 let use_guard = context.config.enable_guard_subquery_tablefunc;
1998
1999 // Errors should result from counts > 1.
2000 let errors = if use_guard {
2001 counts
2002 .flat_map(
2003 mz_expr::TableFunc::GuardSubquerySize {
2004 column_type: col_type.clone().scalar_type,
2005 },
2006 vec![MirScalarExpr::column(inner_arity)],
2007 )
2008 .project(
2009 (0..inner_arity)
2010 .chain(Some(inner_arity + 1))
2011 .collect::<Vec<_>>(),
2012 )
2013 } else {
2014 counts
2015 .filter(vec![MirScalarExpr::column(inner_arity).call_binary(
2016 MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
2017 func::Gt,
2018 )])
2019 .project((0..inner_arity).collect::<Vec<_>>())
2020 .map_one(MirScalarExpr::literal(
2021 Err(mz_expr::EvalError::MultipleRowsFromSubquery),
2022 col_type.clone().scalar_type,
2023 ))
2024 };
2025 // Return `get_select` and any errors added in.
2026 Ok::<_, PlanError>(get_select.union(errors))
2027 })?;
2028 // append Null to anything that didn't return any rows
2029 let default = vec![(Datum::Null, col_type.scalar_type)];
2030 get_inner.lookup(id_gen, guarded, default)
2031 },
2032 )
2033}
2034
2035fn apply_existential_subquery(
2036 id_gen: &mut mz_ore::id_gen::IdGen,
2037 outer: MirRelationExpr,
2038 col_map: &ColumnMap,
2039 cte_map: &mut CteMap,
2040 subquery_expr: HirRelationExpr,
2041 apply_requires_distinct_outer: bool,
2042 context: &Context,
2043) -> Result<MirRelationExpr, PlanError> {
2044 branch(
2045 id_gen,
2046 outer,
2047 col_map,
2048 cte_map,
2049 subquery_expr,
2050 apply_requires_distinct_outer,
2051 context,
2052 |id_gen, expr, get_inner, col_map, cte_map, context| {
2053 let exists = expr
2054 // compute for every row in get_inner
2055 .applied_to(id_gen, get_inner.clone(), col_map, cte_map, context)?
2056 // throw away actual values and just remember whether or not there were __any__ rows
2057 .distinct_by((0..get_inner.arity()).collect())
2058 // Append true to anything that returned any rows.
2059 .map(vec![MirScalarExpr::literal_true()]);
2060
2061 // append False to anything that didn't return any rows
2062 get_inner.lookup(id_gen, exists, vec![(Datum::False, SqlScalarType::Bool)])
2063 },
2064 )
2065}
2066
2067impl AggregateExpr {
2068 fn applied_to(
2069 self,
2070 id_gen: &mut mz_ore::id_gen::IdGen,
2071 col_map: &ColumnMap,
2072 cte_map: &mut CteMap,
2073 inner: &mut MirRelationExpr,
2074 context: &Context,
2075 ) -> Result<mz_expr::AggregateExpr, PlanError> {
2076 let AggregateExpr {
2077 func,
2078 expr,
2079 distinct,
2080 } = self;
2081
2082 Ok(mz_expr::AggregateExpr {
2083 func: func.into_expr(),
2084 expr: expr.applied_to(id_gen, col_map, cte_map, inner, &None, context)?,
2085 distinct,
2086 })
2087 }
2088}
2089
/// Attempts an efficient outer join, if `on` has equijoin structure.
///
/// Both `left` and `right` are decorrelated inputs.
///
/// The first `oa` columns correspond to an outer context: we should do the
/// outer join independently for each prefix. In the case that `on` contains
/// just some equality tests between columns of `left` and `right` and some
/// local predicates, we can employ a relatively simple plan.
///
/// The last `on_subquery_types.len()` columns correspond to results from
/// subqueries defined in the `on` clause - we treat those as theta-join
/// conditions that prohibit the use of the simple plan attempted here.
///
/// # Parameters
///
/// - `left`, `right`: the join inputs; each starts with the same `oa` outer
///   columns.
/// - `on`: the join predicate, expressed over the schema
///   `outer × left × right × on_subqueries`.
/// - `on_subquery_types`: column types of the trailing `on` subquery columns.
/// - `kind`: `LeftOuter`/`FullOuter` add the unmatched `left` rows,
///   `RightOuter`/`FullOuter` add the unmatched `right` rows (each padded with
///   typed nulls); any other kind yields just the inner join.
/// - `oa`: number of outer context columns.
/// - `id_gen`: used for the `let_in` bindings introduced here.
/// - `context`: forwarded to [`OnPredicates`] for classification.
///
/// # Returns
///
/// `Ok(None)` if [`OnPredicates::is_equijoin`] rejects the canonicalized `on`
/// predicate, in which case the caller has to fall back to another strategy.
/// Otherwise `Ok(Some(plan))` with columns `outer × left × right`.
fn attempt_outer_equijoin(
    left: MirRelationExpr,
    right: MirRelationExpr,
    on: MirScalarExpr,
    on_subquery_types: Vec<SqlColumnType>,
    kind: JoinKind,
    oa: usize,
    id_gen: &mut mz_ore::id_gen::IdGen,
    context: &Context,
) -> Result<Option<MirRelationExpr>, PlanError> {
    // TODO(database-issues#6827): In theory, we can be smarter and also handle `on`
    // predicates that reference subqueries as long as these subqueries don't
    // reference `left` and `right` at the same time.
    //
    // TODO(database-issues#6828): This code can be improved as follows:
    //
    // 1. Move the `canonicalize_predicates(...)` call to `applied_to`.
    // 2. Use the canonicalized `on` predicate in the non-equijoin based
    //    lowering strategy.
    // 3. Move the `OnPredicates::new(...)` call to `applied_to`.
    // 4. Pass the classified `OnPredicates` as a parameter.
    // 5. Guard calls of this function with `on_predicates.is_equijoin()`.
    //
    // Steps (1 + 2) require further investigation because we might change the
    // error semantics in case the `on` predicate contains a literal error.

    // Arities of the `left` and `right` inputs without the shared outer
    // prefix, and of the trailing `on` subquery columns.
    let l_type = left.typ();
    let r_type = right.typ();
    let la = l_type.column_types.len() - oa;
    let ra = r_type.column_types.len() - oa;
    let sa = on_subquery_types.len();

    // The output type contains [outer, left, right, sa] attributes.
    let mut output_type = Vec::with_capacity(oa + la + ra + sa);
    output_type.extend(l_type.column_types);
    output_type.extend(r_type.column_types.into_iter().skip(oa));
    output_type.extend(on_subquery_types);

    // Generally healthy to do, but specifically `USING` conditions sometimes
    // put an `AND true` at the end of the `ON` condition.
    //
    // TODO(aalexandrov): maybe we should already be doing this in `applied_to`.
    // However, in that case it's not clear that we won't see regressions if
    // `on` simplifies to a literal error.
    let mut on = vec![on];
    mz_expr::canonicalize::canonicalize_predicates(&mut on, &output_type);

    // Form the left and right types without the outer attributes.
    output_type.drain(0..oa);
    let lt = output_type.drain(0..la).collect_vec();
    let rt = output_type.drain(0..ra).collect_vec();
    // Only the `on` subquery column types may be left at this point.
    assert!(output_type.len() == sa);

    // Classify a copy of the predicates; `on` itself is still needed below as
    // the filter of the inner join.
    let on_predicates = OnPredicates::new(oa, la, ra, sa, on.clone(), context);
    if !on_predicates.is_equijoin(context) {
        return Ok(None);
    }

    // If we've gotten this far, we can do the clever thing.
    // We'll want to use left and right multiple times
    let result = left.let_in(id_gen, |id_gen, get_left| {
        right.let_in(id_gen, |id_gen, get_right| {
            // TODO: we know that we can re-use the arrangements of left and right
            // needed for the inner join with each of the conditional outer joins.
            // It is not clear whether we should hint that, or just let the planner
            // and optimizer run and see what happens.

            // We'll want the inner join (minus repeated columns)
            let join = MirRelationExpr::join(
                vec![get_left.clone(), get_right.clone()],
                (0..oa).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // remove those columns from `right` repeating the first `oa` columns.
            .project(
                (0..(oa + la))
                    .chain((oa + la + oa)..(oa + la + oa + ra))
                    .collect(),
            )
            // apply the filter constraints here, to ensure nulls are not matched.
            .filter(on);

            // We'll want to re-use the results of the join multiple times.
            join.let_in(id_gen, |id_gen, get_join| {
                let mut result = get_join.clone();

                // A collection of keys present in both left and right collections.
                let join_keys = on_predicates.join_keys();
                let both_keys_arity = join_keys.len();
                let both_keys = get_join.restrict(join_keys).distinct();

                // The plan is now to determine the left and right rows matched in the
                // inner join, subtract them from left and right respectively, pad what
                // remains with nulls, and fold them in to `result`.

                both_keys.let_in(id_gen, |_id_gen, get_both| {
                    if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter = kind {
                        // Rows in `left` matched in the inner equijoin. This is
                        // a semi-join between `left` and `both_keys`.
                        let left_present = MirRelationExpr::join_scalars(
                            vec![
                                get_left
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.lhs()),
                                get_both.clone(),
                            ],
                            // Equate the k-th left key expression with the k-th
                            // column of `get_both`, which follows the `oa + la`
                            // columns of `left` in the joined schema.
                            itertools::zip_eq(
                                on_predicates.eq_lhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + la + k)),
                            )
                            .map(|(l_key, b_key)| [l_key, b_key].to_vec())
                            .collect(),
                        )
                        // Keep only the columns of `left`.
                        .project((0..(oa + la)).collect());

                        // Determine the types of nulls to use as filler.
                        let right_fill = rt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, filled with typed nulls.
                        // `left - left_present` are the `left` rows without a match.
                        result = left_present
                            .negate()
                            .union(get_left.clone())
                            .map(right_fill)
                            .union(result);
                    }

                    if let JoinKind::RightOuter | JoinKind::FullOuter = kind {
                        // Rows in `right` matched in the inner equijoin. This
                        // is a semi-join between `right` and `both_keys`.
                        let right_present = MirRelationExpr::join_scalars(
                            vec![
                                get_right
                                    .clone()
                                    // Push local predicates.
                                    .filter(on_predicates.rhs()),
                                get_both,
                            ],
                            // Equate the k-th right key expression with the k-th
                            // column of `get_both`, which follows the `oa + ra`
                            // columns of `right` in the joined schema.
                            itertools::zip_eq(
                                on_predicates.eq_rhs(),
                                (0..both_keys_arity).map(|k| MirScalarExpr::column(oa + ra + k)),
                            )
                            .map(|(r_key, b_key)| [r_key, b_key].to_vec())
                            .collect(),
                        )
                        // Keep only the columns of `right`.
                        .project((0..(oa + ra)).collect());

                        // Determine the types of nulls to use as filler.
                        let left_fill = lt
                            .into_iter()
                            .map(|typ| MirScalarExpr::literal_null(typ.scalar_type))
                            .collect();

                        // Add to `result` absent elements, prepended with typed nulls.
                        // `right - right_present` are the `right` rows without a match.
                        result = right_present
                            .negate()
                            .union(get_right.clone())
                            .map(left_fill)
                            // Permute left fill before right values: the input is
                            // `outer × right × left_fill`, the output must be
                            // `outer × left_fill × right`.
                            .project(
                                itertools::chain!(
                                    0..oa,                 // Preserve `outer`.
                                    oa + ra..oa + la + ra, // Next `la` outputs: the left fill, read from `ra` positions to the right.
                                    oa..oa + ra            // Last `ra` outputs: the right values, read from `la` positions to the left.
                                )
                                .collect(),
                            )
                            .union(result)
                    }

                    Ok::<_, PlanError>(result)
                })
            })
        })
    })?;
    Ok(Some(result))
}
2281
/// A struct that represents the predicates in the `on` clause in a form
/// suitable for efficient planning outer joins with equijoin predicates.
///
/// Instances are built by [`OnPredicates::new`], which classifies every
/// conjunct of the `on` clause into an [`OnPredicate`] variant.
struct OnPredicates {
    /// A store for classified `ON` predicates.
    ///
    /// Predicates that reference a single side are adjusted to assume an
    /// `outer × <side>` schema.
    predicates: Vec<OnPredicate>,
    /// Number of outer context columns.
    ///
    /// Used to prefix the equijoin keys with references to the outer columns.
    oa: usize,
}
2293
impl OnPredicates {
    // Positions of the four logical inputs in the
    // `outer × left × right × on_subqueries` schema, as reported by the
    // `JoinInputMapper` built in `new`.
    const I_OUT: usize = 0; // outer context input position
    const I_LHS: usize = 1; // lhs input position
    const I_RHS: usize = 2; // rhs input position
    const I_SUB: usize = 3; // on subqueries input position

    /// Classify the predicates in the `on` clause of an outer join.
    ///
    /// The other parameters are arities of the input parts:
    ///
    /// - `oa` is the arity of the `outer` context.
    /// - `la` is the arity of the `left` input.
    /// - `ra` is the arity of the `right` input.
    /// - `sa` is the arity of the `on` subqueries.
    ///
    /// The constructor assumes that:
    ///
    /// 1. The `on` parameter will be applied on a result that has the following
    ///    schema `outer × left × right × on_subqueries`.
    /// 2. The `on` parameter is already adjusted to assume that schema.
    /// 3. The `on` parameter is obtained by canonicalizing the original `on:
    ///    MirScalarExpr` with `canonicalize_predicates`.
    ///
    /// Each element of `on` is classified in this order:
    ///
    /// 1. predicates that might error, or that reference an `on` subquery
    ///    column, become [`OnPredicate::Theta`];
    /// 2. binary equalities are classified by the non-outer inputs referenced
    ///    by each operand (`Const`, `Eq` plus the two `IS NOT NULL`
    ///    consequences, `Lhs`, `Rhs`, or `Theta`);
    /// 3. all other predicates are classified by the non-outer inputs they
    ///    reference as a whole (`Const`, `Lhs`, `Rhs`, or `Theta`).
    ///
    /// The `_context` parameter is currently unused.
    fn new(
        oa: usize,
        la: usize,
        ra: usize,
        sa: usize,
        on: Vec<MirScalarExpr>,
        _context: &Context,
    ) -> Self {
        use mz_expr::BinaryFunc::Eq;

        // Re-bind those locally for more compact pattern matching.
        const I_LHS: usize = OnPredicates::I_LHS;
        const I_RHS: usize = OnPredicates::I_RHS;

        // Self parameters.
        // Note: an equality between both sides contributes three entries, so
        // `on.len()` is only an initial capacity, not an upper bound.
        let mut predicates = Vec::with_capacity(on.len());

        // Helpers for populating `predicates`.
        let inner_join_mapper = mz_expr::JoinInputMapper::new_from_input_arities([oa, la, ra, sa]);
        // Rewrites column references from the `outer × left × right` schema
        // to the `outer × right` schema: a right column at `oa + la + i` is
        // mapped to `oa + i`.
        let rhs_permutation = itertools::chain!(0..oa + la, oa..oa + ra).collect::<Vec<_>>();
        // The inputs referenced by `expr`, ignoring the outer context.
        let lookup_inputs = |expr: &MirScalarExpr| -> Vec<usize> {
            inner_join_mapper
                .lookup_inputs(expr)
                .filter(|&i| i != Self::I_OUT)
                .collect()
        };
        // Whether `expr` references a column produced by an `on` subquery.
        let has_subquery_refs = |expr: &MirScalarExpr| -> bool {
            inner_join_mapper
                .lookup_inputs(expr)
                .any(|i| i == Self::I_SUB)
        };

        // Iterate over `on` elements and populate `predicates`.
        for mut predicate in on {
            if predicate.might_error() {
                tracing::debug!(case = "thetajoin (error)", "OnPredicates::new");
                // Treat predicates that can produce a literal error as Theta.
                predicates.push(OnPredicate::Theta(predicate));
            } else if has_subquery_refs(&predicate) {
                tracing::debug!(case = "thetajoin (subquery)", "OnPredicates::new");
                // Treat predicates referencing an `on` subquery as Theta.
                predicates.push(OnPredicate::Theta(predicate));
            } else if let MirScalarExpr::CallBinary {
                func: Eq(_),
                expr1,
                expr2,
            } = &mut predicate
            {
                // Obtain the non-outer inputs referenced by each side.
                let inputs1 = lookup_inputs(expr1);
                let inputs2 = lookup_inputs(expr2);

                match (&inputs1[..], &inputs2[..]) {
                    // Neither side references an input. This could be a
                    // constant expression or an expression that depends only on
                    // the outer context.
                    ([], []) => {
                        predicates.push(OnPredicate::Const(predicate));
                    }
                    // Both sides reference different inputs.
                    ([I_LHS], [I_RHS]) => {
                        // Move the operands out of `predicate`, which is
                        // discarded afterwards.
                        let lhs = expr1.take();
                        let mut rhs = expr2.take();
                        rhs.permute(&rhs_permutation);
                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
                        // An equality can only hold for non-null operands.
                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
                    }
                    // Both sides reference different inputs (swapped).
                    ([I_RHS], [I_LHS]) => {
                        // Same as above, with the operands exchanged so that
                        // `lhs` again refers to the left input.
                        let lhs = expr2.take();
                        let mut rhs = expr1.take();
                        rhs.permute(&rhs_permutation);
                        predicates.push(OnPredicate::Eq(lhs.clone(), rhs.clone()));
                        predicates.push(OnPredicate::LhsConsequence(lhs.call_is_null().not()));
                        predicates.push(OnPredicate::RhsConsequence(rhs.call_is_null().not()));
                    }
                    // Both sides reference the left input or no input.
                    ([I_LHS], [I_LHS]) | ([I_LHS], []) | ([], [I_LHS]) => {
                        predicates.push(OnPredicate::Lhs(predicate));
                    }
                    // Both sides reference the right input or no input.
                    ([I_RHS], [I_RHS]) | ([I_RHS], []) | ([], [I_RHS]) => {
                        predicate.permute(&rhs_permutation);
                        predicates.push(OnPredicate::Rhs(predicate));
                    }
                    // At least one side references more than one input.
                    _ => {
                        tracing::debug!(case = "thetajoin (eq)", "OnPredicates::new");
                        predicates.push(OnPredicate::Theta(predicate));
                    }
                }
            } else {
                // Obtain the non-outer inputs referenced by this predicate.
                let inputs = lookup_inputs(&predicate);

                match &inputs[..] {
                    // The predicate references no inputs. This could be a
                    // constant expression or an expression that depends only on
                    // the outer context.
                    [] => {
                        predicates.push(OnPredicate::Const(predicate));
                    }
                    // The predicate references only the left input.
                    [I_LHS] => {
                        predicates.push(OnPredicate::Lhs(predicate));
                    }
                    // The predicate references only the right input.
                    [I_RHS] => {
                        predicate.permute(&rhs_permutation);
                        predicates.push(OnPredicate::Rhs(predicate));
                    }
                    // The predicate references both inputs.
                    _ => {
                        tracing::debug!(case = "thetajoin (non-eq)", "OnPredicates::new");
                        predicates.push(OnPredicate::Theta(predicate));
                    }
                }
            }
        }

        Self { predicates, oa }
    }

    /// Check if the predicates can be lowered with an equijoin-based strategy.
    ///
    /// Which classifier is used depends on
    /// `context.config.enable_new_outer_join_lowering`:
    ///
    /// - new: there is at least one `Eq` predicate and no `Theta` predicate;
    /// - old: there is at least one `Eq` predicate, every `Eq` predicate
    ///   equates two plain column references, and there are no `Theta`,
    ///   `Const`, `Lhs` or `Rhs` predicates.
    ///
    /// The `LhsConsequence` and `RhsConsequence` variants are not counted by
    /// either classifier.
    fn is_equijoin(&self, context: &Context) -> bool {
        // Count each `OnPredicate` variant in `self.predicates`.
        // `eq_cols` counts the `Eq` predicates whose both sides are columns.
        let (const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt) =
            self.predicates.iter().fold(
                (0, 0, 0, 0, 0, 0),
                |(const_cnt, lhs_cnt, rhs_cnt, eq_cnt, eq_cols, theta_cnt), p| {
                    (
                        const_cnt + usize::from(matches!(p, OnPredicate::Const(..))),
                        lhs_cnt + usize::from(matches!(p, OnPredicate::Lhs(..))),
                        rhs_cnt + usize::from(matches!(p, OnPredicate::Rhs(..))),
                        eq_cnt + usize::from(matches!(p, OnPredicate::Eq(..))),
                        eq_cols + usize::from(matches!(p, OnPredicate::Eq(lhs, rhs) if lhs.is_column() && rhs.is_column())),
                        theta_cnt + usize::from(matches!(p, OnPredicate::Theta(..))),
                    )
                },
            );

        let is_equijion = if context.config.enable_new_outer_join_lowering {
            // New classifier.
            eq_cnt > 0 && theta_cnt == 0
        } else {
            // Old classifier.
            eq_cnt > 0 && eq_cnt == eq_cols && theta_cnt + const_cnt + lhs_cnt + rhs_cnt == 0
        };

        // Log an entry only if this is an equijoin according to the new classifier.
        // This happens regardless of which classifier produced the result.
        if eq_cnt > 0 && theta_cnt == 0 {
            tracing::debug!(
                const_cnt,
                lhs_cnt,
                rhs_cnt,
                eq_cnt,
                eq_cols,
                theta_cnt,
                "OnPredicates::is_equijoin"
            );
        }

        is_equijion
    }

    /// Return the [`JoinKeys`] that represent the keys for the equijoin. The
    /// keys will contain the outer columns as a prefix.
    ///
    /// If every key is a plain column reference the result is
    /// [`JoinKeys::Outputs`] (column indices), otherwise it is
    /// [`JoinKeys::Scalars`] (arbitrary scalar expressions).
    fn join_keys(&self) -> JoinKeys {
        // We could return either the `lhs` or the `rhs` of the keys used to
        // form the inner join as they are equated by the join condition.
        let join_keys = self.eq_lhs().collect::<Vec<_>>();

        if join_keys.iter().all(|k| k.is_column()) {
            tracing::debug!(case = "outputs", "OnPredicates::join_keys");
            JoinKeys::Outputs(join_keys.iter().flat_map(|k| k.as_column()).collect())
        } else {
            tracing::debug!(case = "scalars", "OnPredicates::join_keys");
            JoinKeys::Scalars(join_keys)
        }
    }

    /// Return an iterator over the left-hand sides of all [`OnPredicate::Eq`]
    /// conditions in the predicates list.
    ///
    /// The iterator will start with column references to the outer columns as a
    /// prefix.
    ///
    /// The items are yielded in the same order as by [`Self::eq_rhs`], so the
    /// two iterators can be zipped position by position.
    fn eq_lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
        itertools::chain(
            (0..self.oa).map(MirScalarExpr::column),
            self.predicates.iter().filter_map(|e| match e {
                OnPredicate::Eq(lhs, _) => Some(lhs.clone()),
                _ => None,
            }),
        )
    }

    /// Return an iterator over the right-hand sides of all [`OnPredicate::Eq`]
    /// conditions in the predicates list.
    ///
    /// The iterator will start with column references to the outer columns as a
    /// prefix.
    ///
    /// The right-hand sides were permuted in [`Self::new`] to assume an
    /// `outer × right` schema.
    fn eq_rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
        itertools::chain(
            (0..self.oa).map(MirScalarExpr::column),
            self.predicates.iter().filter_map(|e| match e {
                OnPredicate::Eq(_, rhs) => Some(rhs.clone()),
                _ => None,
            }),
        )
    }

    /// Return an iterator over the [`OnPredicate::Lhs`], [`OnPredicate::LhsConsequence`] and
    /// [`OnPredicate::Const`] conditions in the predicates list.
    ///
    /// These are the predicates that can be evaluated on the left input
    /// alone, in the order in which they were classified.
    fn lhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
        self.predicates.iter().filter_map(|p| match p {
            // We treat Const predicates local to both inputs.
            OnPredicate::Const(p) => Some(p.clone()),
            OnPredicate::Lhs(p) => Some(p.clone()),
            OnPredicate::LhsConsequence(p) => Some(p.clone()),
            _ => None,
        })
    }

    /// Return an iterator over the [`OnPredicate::Rhs`], [`OnPredicate::RhsConsequence`] and
    /// [`OnPredicate::Const`] conditions in the predicates list.
    ///
    /// These are the predicates that can be evaluated on the right input
    /// alone, in the order in which they were classified.
    fn rhs(&self) -> impl Iterator<Item = MirScalarExpr> + '_ {
        self.predicates.iter().filter_map(|p| match p {
            // We treat Const predicates local to both inputs.
            OnPredicate::Const(p) => Some(p.clone()),
            OnPredicate::Rhs(p) => Some(p.clone()),
            OnPredicate::RhsConsequence(p) => Some(p.clone()),
            _ => None,
        })
    }
}
2552
/// A single conjunct of an outer join `ON` clause, classified by the join
/// inputs it references.
enum OnPredicate {
    /// A predicate that is either constant or references only outer columns.
    Const(MirScalarExpr),
    /// A local predicate on the left-hand side of the join, i.e., it references only the left input
    /// and possibly outer columns.
    ///
    /// This is one of the original predicates from the ON clause.
    ///
    /// One _must_ apply this predicate.
    Lhs(MirScalarExpr),
    /// A local predicate on the left-hand side of the join, i.e., it references only the left input
    /// and possibly outer columns.
    ///
    /// This is not one of the original predicates from the ON clause, but is just a consequence
    /// of an original predicate in the ON clause, where the original predicate references both
    /// inputs, but the consequence references only the left input.
    ///
    /// For example, the original predicate `input1.x = input2.a` has the consequence
    /// `input1.x IS NOT NULL`. Applying such a consequence before the input is fed into the join
    /// prevents null skew, and also makes more CSE opportunities available when the left input's key
    /// doesn't have a NOT NULL constraint, saving us an arrangement.
    ///
    /// Applying the predicate is optional, because the original predicate will be applied anyway.
    LhsConsequence(MirScalarExpr),
    /// A local predicate on the right-hand side of the join, i.e., it references only the right
    /// input and possibly outer columns.
    ///
    /// This is one of the original predicates from the ON clause.
    ///
    /// One _must_ apply this predicate.
    Rhs(MirScalarExpr),
    /// A consequence of an original ON predicate that references only the right input; the
    /// right-hand counterpart of [`OnPredicate::LhsConsequence`].
    ///
    /// Applying the predicate is optional, because the original predicate will be applied anyway.
    RhsConsequence(MirScalarExpr),
    /// An equality predicate between the two sides, stored as its left-hand and right-hand
    /// operands.
    Eq(MirScalarExpr, MirScalarExpr),
    /// A predicate that prevents the equijoin-based lowering, e.g. a non-equality predicate
    /// between the two sides.
    // NOTE(review): presumably `dead_code` is allowed because the payload is only ever matched
    // with `..` and never read; confirm before removing the attribute.
    #[allow(dead_code)]
    Theta(MirScalarExpr),
}
2591
/// A set of join keys referencing an input.
///
/// This is used in the [`MirRelationExpr::Join`] lowering code in order to
/// avoid changes (and thereby possible regressions) in plans that have equijoin
/// predicates consisting only of column refs.
///
/// If we were running `CanonicalizeMfp` as part of `NormalizeOps` we might be
/// able to get rid of this code, but as it stands `Map` simplification seems
/// more cumbersome than `Project` simplification, so do this just to be sure.
enum JoinKeys {
    /// Keys that are all plain column references, given as the indices of the
    /// referenced columns. These can be extracted with a `Project` alone.
    Outputs(Vec<usize>),
    /// Keys given as arbitrary scalar expressions. These have to be computed
    /// with a `Map` before they can be projected.
    Scalars(Vec<MirScalarExpr>),
}
2607
2608impl JoinKeys {
2609 fn len(&self) -> usize {
2610 match self {
2611 JoinKeys::Outputs(outputs) => outputs.len(),
2612 JoinKeys::Scalars(scalars) => scalars.len(),
2613 }
2614 }
2615}
2616
/// Extension methods for [`MirRelationExpr`] required in the HIR ⇒ MIR lowering
/// code.
trait LoweringExt {
    /// Reduce `self` to exactly the columns described by `join_keys`, in the
    /// order in which the keys are given.
    ///
    /// See [`MirRelationExpr::restrict`].
    fn restrict(self, join_keys: JoinKeys) -> Self;
}
2623
2624impl LoweringExt for MirRelationExpr {
2625 /// Restrict the set of columns of an input to the sequence of [`JoinKeys`].
2626 fn restrict(self, join_keys: JoinKeys) -> Self {
2627 let num_keys = join_keys.len();
2628 match join_keys {
2629 JoinKeys::Outputs(outputs) => self.project(outputs),
2630 JoinKeys::Scalars(scalars) => {
2631 let input_arity = self.arity();
2632 let outputs = (input_arity..input_arity + num_keys).collect();
2633 self.map(scalars).project(outputs)
2634 }
2635 }
2636 }
2637}