mz_transform/canonicalization/
flatmap_to_map.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Turns `FlatMap` into `Map` when the table function is known to produce exactly one
//! row per input row, and into an empty relation when it is known to produce none.
12
13use mz_expr::visit::Visit;
14use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc};
15use mz_repr::{Datum, Diff, ScalarType};
16
17use crate::TransformCtx;
18
/// Optimizer transform that turns `FlatMap` into `Map` if only one row is produced by
/// flatmap.
///
/// The per-node rewrite is [`FlatMapElimination::action`]; the [`crate::Transform`]
/// implementation applies that rewrite across a whole `MirRelationExpr`.
#[derive(Debug)]
pub struct FlatMapElimination;
22
23impl crate::Transform for FlatMapElimination {
24    fn name(&self) -> &'static str {
25        "FlatMapElimination"
26    }
27
28    #[mz_ore::instrument(
29        target = "optimizer",
30        level = "debug",
31        fields(path.segment = "flatmap_to_map")
32    )]
33    fn actually_perform_transform(
34        &self,
35        relation: &mut MirRelationExpr,
36        _: &mut TransformCtx,
37    ) -> Result<(), crate::TransformError> {
38        relation.visit_mut_post(&mut Self::action)?;
39        mz_repr::explain::trace_plan(&*relation);
40        Ok(())
41    }
42}
43
44impl FlatMapElimination {
45    /// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
46    pub fn action(relation: &mut MirRelationExpr) {
47        if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
48            let (func, with_ordinality) = if let TableFunc::WithOrdinality { inner } = func {
49                // get to the actual function, but remember that we have a WITH ORDINALITY clause.
50                (&**inner, true)
51            } else {
52                (&*func, false)
53            };
54
55            if let TableFunc::GuardSubquerySize { .. } = func {
56                // (`with_ordinality` doesn't matter because this function never emits rows)
57                if let Some(1) = exprs[0].as_literal_int64() {
58                    relation.take_safely(None);
59                }
60            } else if let TableFunc::Wrap { width, .. } = func {
61                if *width >= exprs.len() {
62                    *relation = input.take_dangerous().map(std::mem::take(exprs));
63                    if with_ordinality {
64                        *relation = relation.take_dangerous().map_one(MirScalarExpr::literal(
65                            Ok(Datum::Int64(1)),
66                            ScalarType::Int64,
67                        ));
68                    }
69                }
70            } else if is_supported_unnest(func) {
71                let func = func.clone();
72                let exprs = exprs.clone();
73                use mz_expr::MirScalarExpr;
74                use mz_repr::RowArena;
75                if let MirScalarExpr::Literal(Ok(row), ..) = &exprs[0] {
76                    let temp_storage = RowArena::default();
77                    if let Ok(mut iter) = func.eval(&[row.iter().next().unwrap()], &temp_storage) {
78                        match (iter.next(), iter.next()) {
79                            (None, _) => {
80                                // If there are no elements in the literal argument, no output.
81                                relation.take_safely(None);
82                            }
83                            (Some((row, Diff::ONE)), None) => {
84                                assert_eq!(func.output_type().column_types.len(), 1);
85                                *relation =
86                                    input.take_dangerous().map(vec![MirScalarExpr::Literal(
87                                        Ok(row),
88                                        func.output_type().column_types[0].clone(),
89                                    )]);
90                                if with_ordinality {
91                                    *relation =
92                                        relation.take_dangerous().map_one(MirScalarExpr::literal(
93                                            Ok(Datum::Int64(1)),
94                                            ScalarType::Int64,
95                                        ));
96                                }
97                            }
98                            _ => {}
99                        }
100                    };
101                }
102            }
103        }
104    }
105}
106
107/// Returns `true` for `unnest_~` variants supported by [`FlatMapElimination`].
108fn is_supported_unnest(func: &TableFunc) -> bool {
109    use TableFunc::*;
110    matches!(func, UnnestArray { .. } | UnnestList { .. })
111}