mz_transform/canonicalization/flat_map_elimination.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! For a `FlatMap` whose table function arguments are all non-error literals, evaluates the table
//! function once: if it produces exactly 1 row (with multiplicity 1), the `FlatMap` is turned into
//! a `Map` of that row's values; if it produces 0 rows, the whole expression is turned into an
//! empty constant collection.
13//!
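//! It also optimizes the `Wrap` table function, even when its arguments are not literals: when
//! `Wrap` has at least one argument and its width is greater than or equal to its number of
//! arguments, each input row is wrapped into exactly one output row holding all of the
//! arguments, so the `FlatMap Wrap ...` is replaced by a `Map` of those arguments.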
17
18use itertools::Itertools;
19use mz_expr::visit::Visit;
20use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc};
21use mz_repr::{Diff, Row, RowArena};
22
23use crate::TransformCtx;
24
/// Optimizer transform that rewrites `FlatMap`s known to emit either no rows or exactly one row
/// for every input row into cheaper operators (an empty constant or a `Map`).
///
/// The per-node rewrite lives in [`FlatMapElimination::action`].
#[derive(Debug)]
pub struct FlatMapElimination;
28
29impl crate::Transform for FlatMapElimination {
30    fn name(&self) -> &'static str {
31        "FlatMapElimination"
32    }
33
34    #[mz_ore::instrument(
35        target = "optimizer",
36        level = "debug",
37        fields(path.segment = "flat_map_elimination")
38    )]
39    fn actually_perform_transform(
40        &self,
41        relation: &mut MirRelationExpr,
42        _: &mut TransformCtx,
43    ) -> Result<(), crate::TransformError> {
44        relation.visit_mut_post(&mut Self::action)?;
45        mz_repr::explain::trace_plan(&*relation);
46        Ok(())
47    }
48}
49
50impl FlatMapElimination {
51    /// Apply `FlatMapElimination` to the root of the given `MirRelationExpr`.
52    pub fn action(relation: &mut MirRelationExpr) {
53        // Treat Wrap specially: we can sometimes optimize it out even when it has non-literal
54        // arguments.
55        //
56        // (No need to look for WithOrdinality here, as that never occurs with Wrap: users can't
57        // call Wrap directly; we only create calls to Wrap ourselves, and we don't use
58        // WithOrdinality on it.)
59        if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
60            if let TableFunc::Wrap { width, .. } = func {
61                if *width >= exprs.len() {
62                    *relation = input.take_dangerous().map(std::mem::take(exprs));
63                }
64            }
65        }
66        // For all other table functions (and Wraps that are not covered by the above), check
67        // whether all arguments are literals (with no errors), in which case we'll evaluate the
68        // table function and check how many output rows it has, and maybe turn the FlatMap into
69        // something simpler.
70        if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
71            if let Some(args) = exprs
72                .iter()
73                .map(|e| e.as_literal_non_error())
74                .collect::<Option<Vec<_>>>()
75            {
76                let temp_storage = RowArena::new();
77                let (first, second) = match func.eval(&args, &temp_storage) {
78                    Ok(mut r) => (r.next(), r.next()),
79                    // don't play with errors
80                    Err(_) => return,
81                };
82                match (first, second) {
83                    // The table function evaluated to an empty collection.
84                    (None, _) => {
85                        relation.take_safely(None);
86                    }
87                    // The table function evaluated to a collection with exactly 1 row.
88                    (Some((first_row, Diff::ONE)), None) => {
89                        let types = func.output_type().column_types;
90                        let map_exprs = first_row
91                            .into_iter()
92                            .zip_eq(types)
93                            .map(|(d, typ)| MirScalarExpr::Literal(Ok(Row::pack_slice(&[d])), typ))
94                            .collect();
95                        *relation = input.take_dangerous().map(map_exprs);
96                    }
97                    // The table function evaluated to a collection with more than 1 row; nothing to do.
98                    _ => {}
99                }
100            }
101        }
102    }
103}