// mz_transform/canonicalization/flat_map_elimination.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
//! For a `FlatMap` whose table function's arguments are all non-error literals, evaluates the
//! table function once and turns the `FlatMap` into a `Map` if it produces exactly one row (with a
//! multiplicity of 1), or into an empty constant collection if it produces no rows.
//!
//! It also optimizes the `Wrap` table function specially: when `Wrap`'s width is at least its
//! number of arguments, it replaces the `FlatMap Wrap ...` with a `Map` of those arguments, because
//! such a `Wrap` has no effect beyond appending its arguments as columns.
17
18use itertools::Itertools;
19use mz_expr::visit::Visit;
20use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc};
21use mz_repr::{Diff, Row, RowArena};
22
23use crate::TransformCtx;
24
/// Optimizer transform that removes `FlatMap`s which are guaranteed to yield either zero rows or
/// exactly one row for each input row, replacing them with simpler operators.
#[derive(Debug)]
pub struct FlatMapElimination;
28
29impl crate::Transform for FlatMapElimination {
30 fn name(&self) -> &'static str {
31 "FlatMapElimination"
32 }
33
34 #[mz_ore::instrument(
35 target = "optimizer",
36 level = "debug",
37 fields(path.segment = "flat_map_elimination")
38 )]
39 fn actually_perform_transform(
40 &self,
41 relation: &mut MirRelationExpr,
42 _: &mut TransformCtx,
43 ) -> Result<(), crate::TransformError> {
44 relation.visit_mut_post(&mut Self::action)?;
45 mz_repr::explain::trace_plan(&*relation);
46 Ok(())
47 }
48}
49
50impl FlatMapElimination {
51 /// Apply `FlatMapElimination` to the root of the given `MirRelationExpr`.
52 pub fn action(relation: &mut MirRelationExpr) {
53 // Treat Wrap specially: we can sometimes optimize it out even when it has non-literal
54 // arguments.
55 //
56 // (No need to look for WithOrdinality here, as that never occurs with Wrap: users can't
57 // call Wrap directly; we only create calls to Wrap ourselves, and we don't use
58 // WithOrdinality on it.)
59 if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
60 if let TableFunc::Wrap { width, .. } = func {
61 if *width >= exprs.len() {
62 *relation = input.take_dangerous().map(std::mem::take(exprs));
63 }
64 }
65 }
66 // For all other table functions (and Wraps that are not covered by the above), check
67 // whether all arguments are literals (with no errors), in which case we'll evaluate the
68 // table function and check how many output rows it has, and maybe turn the FlatMap into
69 // something simpler.
70 if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
71 if let Some(args) = exprs
72 .iter()
73 .map(|e| e.as_literal_non_error())
74 .collect::<Option<Vec<_>>>()
75 {
76 let temp_storage = RowArena::new();
77 let (first, second) = match func.eval(&args, &temp_storage) {
78 Ok(mut r) => (r.next(), r.next()),
79 // don't play with errors
80 Err(_) => return,
81 };
82 match (first, second) {
83 // The table function evaluated to an empty collection.
84 (None, _) => {
85 relation.take_safely(None);
86 }
87 // The table function evaluated to a collection with exactly 1 row.
88 (Some((first_row, Diff::ONE)), None) => {
89 let types = func.output_type().column_types;
90 let map_exprs = first_row
91 .into_iter()
92 .zip_eq(types)
93 .map(|(d, typ)| MirScalarExpr::Literal(Ok(Row::pack_slice(&[d])), typ))
94 .collect();
95 *relation = input.take_dangerous().map(map_exprs);
96 }
97 // The table function evaluated to a collection with more than 1 row; nothing to do.
98 _ => {}
99 }
100 }
101 }
102 }
103}