mz_transform/compound/
union.rs1use std::iter;
16
17use mz_expr::MirRelationExpr;
18use mz_expr::visit::Visit;
19use mz_repr::SqlRelationType;
20
21use crate::TransformCtx;
22
/// Optimizer transform that fuses nested unions into their parent `Union`.
///
/// A `Union` input that is itself a `Union` is spliced into the parent. A
/// `Union` input of the form `Negate(Union { .. })` is also spliced in, with
/// each of the inner union's inputs negated individually. Other inputs are
/// left as they are.
///
/// This is a stateless unit struct; the rewrite itself lives in
/// [`UnionNegateFusion::action`].
#[derive(Debug)]
pub struct UnionNegateFusion;
26
27impl crate::Transform for UnionNegateFusion {
28    fn name(&self) -> &'static str {
29        "UnionNegateFusion"
30    }
31
32    #[mz_ore::instrument(
33        target = "optimizer",
34        level = "debug",
35        fields(path.segment = "union_negate")
36    )]
37    fn actually_perform_transform(
38        &self,
39        relation: &mut MirRelationExpr,
40        _: &mut TransformCtx,
41    ) -> Result<(), crate::TransformError> {
42        relation.visit_mut_post(&mut Self::action)?;
43        mz_repr::explain::trace_plan(&*relation);
44        Ok(())
45    }
46}
47
48impl UnionNegateFusion {
49    pub fn action(relation: &mut MirRelationExpr) {
53        use MirRelationExpr::*;
54        if let Union { base, inputs } = relation {
55            let can_fuse = iter::once(&**base).chain(&*inputs).any(|input| -> bool {
56                match input {
57                    Union { .. } => true,
58                    Negate { input } => matches!(**input, Union { .. }),
59                    _ => false,
60                }
61            });
62            if can_fuse {
63                let mut new_inputs: Vec<MirRelationExpr> = vec![];
64                for input in iter::once(base.as_mut()).chain(inputs) {
65                    let input = input.take_dangerous();
66                    match input {
67                        Union { base, inputs } => {
68                            new_inputs.push(*base);
69                            new_inputs.extend(inputs);
70                        }
71                        Negate { input } if matches!(*input, Union { .. }) => {
72                            if let Union { base, inputs } = *input {
73                                new_inputs.push(base.negate());
74                                new_inputs.extend(inputs.into_iter().map(|x| x.negate()));
75                            } else {
76                                unreachable!()
77                            }
78                        }
79                        _ => new_inputs.push(input),
80                    }
81                }
82
83                for new_input in new_inputs.iter_mut() {
85                    crate::fusion::negate::Negate::action(new_input);
86                }
87
88                assert!(!new_inputs.is_empty());
91                *relation = MirRelationExpr::union_many(new_inputs, SqlRelationType::empty());
92            }
93        }
94    }
95}