1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
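//! Turns `FlatMap` into `Map` when the table function is known to produce
//! exactly one row for each input row, and into an empty relation when it is
//! known to produce no rows at all.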
1213use mz_expr::visit::Visit;
14use mz_expr::{MirRelationExpr, TableFunc};
15use mz_repr::Diff;
1617use crate::TransformCtx;
/// Optimizer transform that replaces a `FlatMap` with a plain `Map` when the
/// table function is known to emit a single row per input row.
///
/// The per-node rewrite is implemented by [`FlatMapToMap::action`].
#[derive(Debug)]
pub struct FlatMapToMap;
2223impl crate::Transform for FlatMapToMap {
24fn name(&self) -> &'static str {
25"FlatMapToMap"
26}
2728#[mz_ore::instrument(
29 target = "optimizer",
30 level = "debug",
31 fields(path.segment = "flatmap_to_map")
32 )]
33fn actually_perform_transform(
34&self,
35 relation: &mut MirRelationExpr,
36_: &mut TransformCtx,
37 ) -> Result<(), crate::TransformError> {
38 relation.visit_mut_post(&mut Self::action)?;
39 mz_repr::explain::trace_plan(&*relation);
40Ok(())
41 }
42}
4344impl FlatMapToMap {
45/// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
46pub fn action(relation: &mut MirRelationExpr) {
47if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
48if let TableFunc::Wrap { width, .. } = func {
49if *width >= exprs.len() {
50*relation = input.take_dangerous().map(std::mem::take(exprs));
51 }
52 } else if is_supported_unnest(func) {
53let func = func.clone();
54let exprs = exprs.clone();
55use mz_expr::MirScalarExpr;
56use mz_repr::RowArena;
57if let MirScalarExpr::Literal(Ok(row), ..) = &exprs[0] {
58let temp_storage = RowArena::default();
59if let Ok(mut iter) = func.eval(&[row.iter().next().unwrap()], &temp_storage) {
60match (iter.next(), iter.next()) {
61 (None, _) => {
62// If there are no elements in the literal argument, no output.
63relation.take_safely(None);
64 }
65 (Some((row, Diff::ONE)), None) => {
66*relation =
67 input.take_dangerous().map(vec![MirScalarExpr::Literal(
68Ok(row),
69 func.output_type().column_types[0].clone(),
70 )]);
71 }
72_ => {}
73 }
74 };
75 }
76 }
77 }
78 }
79}
8081/// Returns `true` for `unnest_~` variants supported by [`FlatMapToMap`].
82fn is_supported_unnest(func: &TableFunc) -> bool {
83use TableFunc::*;
84matches!(func, UnnestArray { .. } | UnnestList { .. })
85}