mz_transform/canonicalization/
flatmap_to_map.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Turns `FlatMap` into `Map` when the table function is known to produce exactly one
//! row per input row, and removes `FlatMap`s over literal unnests that produce no rows.
12
13use mz_expr::visit::Visit;
14use mz_expr::{MirRelationExpr, TableFunc};
15use mz_repr::Diff;
16
17use crate::TransformCtx;
18
/// Optimizer transform that rewrites a `FlatMap` into an equivalent `Map` whenever the
/// flat-mapped table function is known to emit a single row.
#[derive(Debug)]
pub struct FlatMapToMap;
22
23impl crate::Transform for FlatMapToMap {
24    fn name(&self) -> &'static str {
25        "FlatMapToMap"
26    }
27
28    #[mz_ore::instrument(
29        target = "optimizer",
30        level = "debug",
31        fields(path.segment = "flatmap_to_map")
32    )]
33    fn actually_perform_transform(
34        &self,
35        relation: &mut MirRelationExpr,
36        _: &mut TransformCtx,
37    ) -> Result<(), crate::TransformError> {
38        relation.visit_mut_post(&mut Self::action)?;
39        mz_repr::explain::trace_plan(&*relation);
40        Ok(())
41    }
42}
43
44impl FlatMapToMap {
45    /// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
46    pub fn action(relation: &mut MirRelationExpr) {
47        if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
48            if let TableFunc::Wrap { width, .. } = func {
49                if *width >= exprs.len() {
50                    *relation = input.take_dangerous().map(std::mem::take(exprs));
51                }
52            } else if is_supported_unnest(func) {
53                let func = func.clone();
54                let exprs = exprs.clone();
55                use mz_expr::MirScalarExpr;
56                use mz_repr::RowArena;
57                if let MirScalarExpr::Literal(Ok(row), ..) = &exprs[0] {
58                    let temp_storage = RowArena::default();
59                    if let Ok(mut iter) = func.eval(&[row.iter().next().unwrap()], &temp_storage) {
60                        match (iter.next(), iter.next()) {
61                            (None, _) => {
62                                // If there are no elements in the literal argument, no output.
63                                relation.take_safely(None);
64                            }
65                            (Some((row, Diff::ONE)), None) => {
66                                *relation =
67                                    input.take_dangerous().map(vec![MirScalarExpr::Literal(
68                                        Ok(row),
69                                        func.output_type().column_types[0].clone(),
70                                    )]);
71                            }
72                            _ => {}
73                        }
74                    };
75                }
76            }
77        }
78    }
79}
80
81/// Returns `true` for `unnest_~` variants supported by [`FlatMapToMap`].
82fn is_supported_unnest(func: &TableFunc) -> bool {
83    use TableFunc::*;
84    matches!(func, UnnestArray { .. } | UnnestList { .. })
85}