mz_transform/canonicalization/
flatmap_to_map.rs1use mz_expr::visit::Visit;
14use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc};
15use mz_repr::{Datum, Diff, ScalarType};
16
17use crate::TransformCtx;
18
#[derive(Debug)]
/// Optimizer transform that replaces `FlatMap` expressions whose output can be determined
/// statically with cheaper equivalents (see `FlatMapElimination::action` for the exact rules).
pub struct FlatMapElimination;
22
23impl crate::Transform for FlatMapElimination {
24 fn name(&self) -> &'static str {
25 "FlatMapElimination"
26 }
27
28 #[mz_ore::instrument(
29 target = "optimizer",
30 level = "debug",
31 fields(path.segment = "flatmap_to_map")
32 )]
33 fn actually_perform_transform(
34 &self,
35 relation: &mut MirRelationExpr,
36 _: &mut TransformCtx,
37 ) -> Result<(), crate::TransformError> {
38 relation.visit_mut_post(&mut Self::action)?;
39 mz_repr::explain::trace_plan(&*relation);
40 Ok(())
41 }
42}
43
44impl FlatMapElimination {
45 pub fn action(relation: &mut MirRelationExpr) {
47 if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
48 let (func, with_ordinality) = if let TableFunc::WithOrdinality { inner } = func {
49 (&**inner, true)
51 } else {
52 (&*func, false)
53 };
54
55 if let TableFunc::GuardSubquerySize { .. } = func {
56 if let Some(1) = exprs[0].as_literal_int64() {
58 relation.take_safely(None);
59 }
60 } else if let TableFunc::Wrap { width, .. } = func {
61 if *width >= exprs.len() {
62 *relation = input.take_dangerous().map(std::mem::take(exprs));
63 if with_ordinality {
64 *relation = relation.take_dangerous().map_one(MirScalarExpr::literal(
65 Ok(Datum::Int64(1)),
66 ScalarType::Int64,
67 ));
68 }
69 }
70 } else if is_supported_unnest(func) {
71 let func = func.clone();
72 let exprs = exprs.clone();
73 use mz_expr::MirScalarExpr;
74 use mz_repr::RowArena;
75 if let MirScalarExpr::Literal(Ok(row), ..) = &exprs[0] {
76 let temp_storage = RowArena::default();
77 if let Ok(mut iter) = func.eval(&[row.iter().next().unwrap()], &temp_storage) {
78 match (iter.next(), iter.next()) {
79 (None, _) => {
80 relation.take_safely(None);
82 }
83 (Some((row, Diff::ONE)), None) => {
84 assert_eq!(func.output_type().column_types.len(), 1);
85 *relation =
86 input.take_dangerous().map(vec![MirScalarExpr::Literal(
87 Ok(row),
88 func.output_type().column_types[0].clone(),
89 )]);
90 if with_ordinality {
91 *relation =
92 relation.take_dangerous().map_one(MirScalarExpr::literal(
93 Ok(Datum::Int64(1)),
94 ScalarType::Int64,
95 ));
96 }
97 }
98 _ => {}
99 }
100 };
101 }
102 }
103 }
104 }
105}
106
107fn is_supported_unnest(func: &TableFunc) -> bool {
109 use TableFunc::*;
110 matches!(func, UnnestArray { .. } | UnnestList { .. })
111}