mz_transform/compound/
union.rs1use std::iter;
16
17use mz_expr::MirRelationExpr;
18use mz_expr::visit::Visit;
19use mz_repr::SqlRelationType;
20
21use crate::TransformCtx;
22
/// Optimizer transform that hoists the inputs of nested `Union`s — and of
/// `Negate`d `Union`s, with the negation distributed over them — into the
/// enclosing `Union`.
#[derive(Debug)]
pub struct UnionNegateFusion;
26
27impl crate::Transform for UnionNegateFusion {
28 fn name(&self) -> &'static str {
29 "UnionNegateFusion"
30 }
31
32 #[mz_ore::instrument(
33 target = "optimizer",
34 level = "debug",
35 fields(path.segment = "union_negate")
36 )]
37 fn actually_perform_transform(
38 &self,
39 relation: &mut MirRelationExpr,
40 _: &mut TransformCtx,
41 ) -> Result<(), crate::TransformError> {
42 relation.visit_mut_post(&mut Self::action)?;
43 mz_repr::explain::trace_plan(&*relation);
44 Ok(())
45 }
46}
47
48impl UnionNegateFusion {
49 pub fn action(relation: &mut MirRelationExpr) {
53 use MirRelationExpr::*;
54 if let Union { base, inputs } = relation {
55 let can_fuse = iter::once(&**base).chain(&*inputs).any(|input| -> bool {
56 match input {
57 Union { .. } => true,
58 Negate { input } => matches!(**input, Union { .. }),
59 _ => false,
60 }
61 });
62 if can_fuse {
63 let mut new_inputs: Vec<MirRelationExpr> = vec![];
64 for input in iter::once(base.as_mut()).chain(inputs) {
65 let input = input.take_dangerous();
66 match input {
67 Union { base, inputs } => {
68 new_inputs.push(*base);
69 new_inputs.extend(inputs);
70 }
71 Negate { input } if matches!(*input, Union { .. }) => {
72 if let Union { base, inputs } = *input {
73 new_inputs.push(base.negate());
74 new_inputs.extend(inputs.into_iter().map(|x| x.negate()));
75 } else {
76 unreachable!()
77 }
78 }
79 _ => new_inputs.push(input),
80 }
81 }
82
83 for new_input in new_inputs.iter_mut() {
85 crate::fusion::negate::Negate::action(new_input);
86 }
87
88 assert!(!new_inputs.is_empty());
91 *relation = MirRelationExpr::union_many(new_inputs, SqlRelationType::empty());
92 }
93 }
94 }
95}