mz_transform/
union_cancel.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Detects an input being unioned with its negation and cancels them out
11
12use itertools::Itertools;
13use mz_expr::MirRelationExpr;
14use mz_expr::visit::Visit;
15
16use crate::{TransformCtx, TransformError};
17
18/// Detects an input being unioned with its negation and cancels them out
19///
20/// `UnionBranchCancellation` is recursion-safe, but this is not immediately trivial:
21/// It relies on the equality of certain `MirRelationExpr`s, which is a scary thing in a WMR
22/// context, because `Get x` can mean not-equal things in different Let bindings.
23/// However, this problematic case can't happen here, because when `UnionBranchCancellation` is
24/// looking at two Gets to the same Id, then these have to be under the same Let binding.
25/// This is because the recursion of `compare_branches` starts from two things in the same Let
26/// binding (from two inputs of a Union), and then we don't make any `compare_branches` call
27/// where `relation` and `other` are in different Let bindings.
28#[derive(Debug)]
29pub struct UnionBranchCancellation;
30
31impl crate::Transform for UnionBranchCancellation {
32    fn name(&self) -> &'static str {
33        "UnionBranchCancellation"
34    }
35
36    #[mz_ore::instrument(
37        target = "optimizer",
38        level = "debug",
39        fields(path.segment = "union_branch_cancellation")
40    )]
41    fn actually_perform_transform(
42        &self,
43        relation: &mut MirRelationExpr,
44        _: &mut TransformCtx,
45    ) -> Result<(), TransformError> {
46        let result = relation.try_visit_mut_post(&mut |e| self.action(e));
47        mz_repr::explain::trace_plan(&*relation);
48        result
49    }
50}
51
/// Outcome of comparing two branches of a union to decide whether they
/// cancel each other out.
enum BranchCmp {
    /// Both branches are equivalent in the sense that they produce exactly
    /// the same results.
    Equivalent,
    /// Both branches are equivalent except that one of them produces negated
    /// row counts, so unioning them cancels them out.
    Inverse,
    /// The branches are not equivalent in any way.
    Distinct,
}

impl BranchCmp {
    /// Adjusts a comparison result for a `Negate` operator found at the top
    /// of one of the two branches that were just compared.
    ///
    /// `Equivalent` and `Inverse` swap places; `Distinct` is unaffected.
    fn inverse(self) -> Self {
        match self {
            // A sign flip cannot make unrelated branches related.
            Self::Distinct => Self::Distinct,
            Self::Inverse => Self::Equivalent,
            Self::Equivalent => Self::Inverse,
        }
    }
}
76
77impl UnionBranchCancellation {
78    /// Detects an input being unioned with its negation and cancels them out
79    pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), TransformError> {
80        if let MirRelationExpr::Union { base, inputs } = relation {
81            // Compares a union branch against the remaining branches in the union
82            // with opposite sign until it finds a branch that cancels the given one
83            // and returns its position.
84            let matching_negation = |input: &MirRelationExpr,
85                                     sign: bool,
86                                     inputs: &[MirRelationExpr],
87                                     input_signs: &[bool],
88                                     start_idx: usize|
89             -> Option<usize> {
90                for i in start_idx..inputs.len() {
91                    // Only compare branches with opposite signs
92                    if sign != input_signs[i] {
93                        if let BranchCmp::Inverse = Self::compare_branches(input, &inputs[i]) {
94                            return Some(i);
95                        }
96                    }
97                }
98                None
99            };
100
101            let base_sign = Self::branch_sign(base);
102            let input_signs = inputs.iter().map(Self::branch_sign).collect_vec();
103
104            // Compare branches if there is at least a negated branch
105            if std::iter::once(&base_sign).chain(&input_signs).any(|x| *x) {
106                if let Some(j) = matching_negation(&*base, base_sign, inputs, &input_signs, 0) {
107                    let relation_typ = base.typ();
108                    **base = MirRelationExpr::constant(vec![], relation_typ.clone());
109                    inputs[j] = MirRelationExpr::constant(vec![], relation_typ);
110                }
111
112                for i in 0..inputs.len() {
113                    if let Some(j) =
114                        matching_negation(&inputs[i], input_signs[i], inputs, &input_signs, i + 1)
115                    {
116                        let relation_typ = inputs[i].typ();
117                        inputs[i] = MirRelationExpr::constant(vec![], relation_typ.clone());
118                        inputs[j] = MirRelationExpr::constant(vec![], relation_typ);
119                    }
120                }
121            }
122        }
123        Ok(())
124    }
125
126    /// Returns the sign of a given union branch. The sign is `true` if the branch contains
127    /// an odd number of Negate operators within a chain of Map, Filter and Project
128    /// operators, and `false` otherwise.
129    ///
130    /// This sign is pre-computed for all union branches in order to avoid performing
131    /// expensive comparisons of branches with the same sign since they can't possibly
132    /// cancel each other.
133    fn branch_sign(branch: &MirRelationExpr) -> bool {
134        let mut relation = branch;
135        let mut sign = false;
136        loop {
137            match relation {
138                MirRelationExpr::Negate { input } => {
139                    sign ^= true;
140                    relation = &**input;
141                }
142                MirRelationExpr::Map { input, .. }
143                | MirRelationExpr::Filter { input, .. }
144                | MirRelationExpr::Project { input, .. } => {
145                    relation = &**input;
146                }
147                _ => return sign,
148            }
149        }
150    }
151
152    /// Compares two branches to check whether they produce the same results but
153    /// with negated row count values, ie. one of them contains an extra Negate operator.
154    /// Negate operators may appear interleaved with Map, Filter and Project
155    /// operators, but these operators must appear in the same order in both branches.
156    fn compare_branches(relation: &MirRelationExpr, other: &MirRelationExpr) -> BranchCmp {
157        match (relation, other) {
158            (
159                MirRelationExpr::Negate { input: input1 },
160                MirRelationExpr::Negate { input: input2 },
161            ) => Self::compare_branches(&*input1, &*input2),
162            (r, MirRelationExpr::Negate { input }) | (MirRelationExpr::Negate { input }, r) => {
163                Self::compare_branches(&*input, r).inverse()
164            }
165            (
166                MirRelationExpr::Map {
167                    input: input1,
168                    scalars: scalars1,
169                },
170                MirRelationExpr::Map {
171                    input: input2,
172                    scalars: scalars2,
173                },
174            ) => {
175                if scalars1 == scalars2 {
176                    Self::compare_branches(&*input1, &*input2)
177                } else {
178                    BranchCmp::Distinct
179                }
180            }
181            (
182                MirRelationExpr::Filter {
183                    input: input1,
184                    predicates: predicates1,
185                },
186                MirRelationExpr::Filter {
187                    input: input2,
188                    predicates: predicates2,
189                },
190            ) => {
191                if predicates1 == predicates2 {
192                    Self::compare_branches(&*input1, &*input2)
193                } else {
194                    BranchCmp::Distinct
195                }
196            }
197            (
198                MirRelationExpr::Project {
199                    input: input1,
200                    outputs: outputs1,
201                },
202                MirRelationExpr::Project {
203                    input: input2,
204                    outputs: outputs2,
205                },
206            ) => {
207                if outputs1 == outputs2 {
208                    Self::compare_branches(&*input1, &*input2)
209                } else {
210                    BranchCmp::Distinct
211                }
212            }
213            _ => {
214                if relation == other {
215                    BranchCmp::Equivalent
216                } else {
217                    BranchCmp::Distinct
218                }
219            }
220        }
221    }
222}