mz_transform/union_cancel.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Detects an input being unioned with its negation and cancels them out
11
12use itertools::Itertools;
13use mz_expr::MirRelationExpr;
14use mz_expr::visit::Visit;
15
16use crate::{TransformCtx, TransformError};
17
/// Transform that finds a `Union` input paired with its own negation and
/// cancels the two against each other.
///
/// # Recursion safety
///
/// `UnionBranchCancellation` is recursion-safe, although the reason is subtle.
/// The transform depends on equality of `MirRelationExpr`s, which is dangerous
/// in a WMR context: a `Get x` may denote different collections when it occurs
/// in different Let bindings.
///
/// That situation cannot arise here. Whenever the transform looks at two Gets
/// of the same Id, both are under the same Let binding, because:
///
/// * the recursion of `compare_branches` starts from two inputs of one Union,
///   which live in the same Let binding, and
/// * no `compare_branches` call is ever made with `relation` and `other`
///   located in different Let bindings.
#[derive(Debug)]
pub struct UnionBranchCancellation;
30
31impl crate::Transform for UnionBranchCancellation {
32 fn name(&self) -> &'static str {
33 "UnionBranchCancellation"
34 }
35
36 #[mz_ore::instrument(
37 target = "optimizer",
38 level = "debug",
39 fields(path.segment = "union_branch_cancellation")
40 )]
41 fn actually_perform_transform(
42 &self,
43 relation: &mut MirRelationExpr,
44 _: &mut TransformCtx,
45 ) -> Result<(), TransformError> {
46 let result = relation.try_visit_mut_post(&mut |e| self.action(e));
47 mz_repr::explain::trace_plan(&*relation);
48 result
49 }
50}
51
/// Outcome of comparing two branches of a union for cancellation purposes.
enum BranchCmp {
    /// Both branches are equivalent, in the sense that they produce exactly
    /// the same results.
    Equivalent,
    /// Both branches are equivalent except that one of them produces negated
    /// row counts, so the two cancel each other out.
    Inverse,
    /// The branches are not equivalent in any way.
    Distinct,
}

impl BranchCmp {
    /// Adjusts a comparison result to account for a Negate operator found at
    /// the top of one of the two branches that were just compared.
    ///
    /// `Equivalent` and `Inverse` swap places; `Distinct` is returned as is.
    fn inverse(self) -> Self {
        match self {
            Self::Equivalent => Self::Inverse,
            Self::Inverse => Self::Equivalent,
            // Unrelated branches stay unrelated no matter how many Negates
            // are stacked on top of them.
            unrelated @ Self::Distinct => unrelated,
        }
    }
}
76
77impl UnionBranchCancellation {
78 /// Detects an input being unioned with its negation and cancels them out
79 pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), TransformError> {
80 if let MirRelationExpr::Union { base, inputs } = relation {
81 // Compares a union branch against the remaining branches in the union
82 // with opposite sign until it finds a branch that cancels the given one
83 // and returns its position.
84 let matching_negation = |input: &MirRelationExpr,
85 sign: bool,
86 inputs: &[MirRelationExpr],
87 input_signs: &[bool],
88 start_idx: usize|
89 -> Option<usize> {
90 for i in start_idx..inputs.len() {
91 // Only compare branches with opposite signs
92 if sign != input_signs[i] {
93 if let BranchCmp::Inverse = Self::compare_branches(input, &inputs[i]) {
94 return Some(i);
95 }
96 }
97 }
98 None
99 };
100
101 let base_sign = Self::branch_sign(base);
102 let input_signs = inputs.iter().map(Self::branch_sign).collect_vec();
103
104 // Compare branches if there is at least a negated branch
105 if std::iter::once(&base_sign).chain(&input_signs).any(|x| *x) {
106 if let Some(j) = matching_negation(&*base, base_sign, inputs, &input_signs, 0) {
107 let relation_typ = base.typ();
108 **base = MirRelationExpr::constant(vec![], relation_typ.clone());
109 inputs[j] = MirRelationExpr::constant(vec![], relation_typ);
110 }
111
112 for i in 0..inputs.len() {
113 if let Some(j) =
114 matching_negation(&inputs[i], input_signs[i], inputs, &input_signs, i + 1)
115 {
116 let relation_typ = inputs[i].typ();
117 inputs[i] = MirRelationExpr::constant(vec![], relation_typ.clone());
118 inputs[j] = MirRelationExpr::constant(vec![], relation_typ);
119 }
120 }
121 }
122 }
123 Ok(())
124 }
125
126 /// Returns the sign of a given union branch. The sign is `true` if the branch contains
127 /// an odd number of Negate operators within a chain of Map, Filter and Project
128 /// operators, and `false` otherwise.
129 ///
130 /// This sign is pre-computed for all union branches in order to avoid performing
131 /// expensive comparisons of branches with the same sign since they can't possibly
132 /// cancel each other.
133 fn branch_sign(branch: &MirRelationExpr) -> bool {
134 let mut relation = branch;
135 let mut sign = false;
136 loop {
137 match relation {
138 MirRelationExpr::Negate { input } => {
139 sign ^= true;
140 relation = &**input;
141 }
142 MirRelationExpr::Map { input, .. }
143 | MirRelationExpr::Filter { input, .. }
144 | MirRelationExpr::Project { input, .. } => {
145 relation = &**input;
146 }
147 _ => return sign,
148 }
149 }
150 }
151
152 /// Compares two branches to check whether they produce the same results but
153 /// with negated row count values, ie. one of them contains an extra Negate operator.
154 /// Negate operators may appear interleaved with Map, Filter and Project
155 /// operators, but these operators must appear in the same order in both branches.
156 fn compare_branches(relation: &MirRelationExpr, other: &MirRelationExpr) -> BranchCmp {
157 match (relation, other) {
158 (
159 MirRelationExpr::Negate { input: input1 },
160 MirRelationExpr::Negate { input: input2 },
161 ) => Self::compare_branches(&*input1, &*input2),
162 (r, MirRelationExpr::Negate { input }) | (MirRelationExpr::Negate { input }, r) => {
163 Self::compare_branches(&*input, r).inverse()
164 }
165 (
166 MirRelationExpr::Map {
167 input: input1,
168 scalars: scalars1,
169 },
170 MirRelationExpr::Map {
171 input: input2,
172 scalars: scalars2,
173 },
174 ) => {
175 if scalars1 == scalars2 {
176 Self::compare_branches(&*input1, &*input2)
177 } else {
178 BranchCmp::Distinct
179 }
180 }
181 (
182 MirRelationExpr::Filter {
183 input: input1,
184 predicates: predicates1,
185 },
186 MirRelationExpr::Filter {
187 input: input2,
188 predicates: predicates2,
189 },
190 ) => {
191 if predicates1 == predicates2 {
192 Self::compare_branches(&*input1, &*input2)
193 } else {
194 BranchCmp::Distinct
195 }
196 }
197 (
198 MirRelationExpr::Project {
199 input: input1,
200 outputs: outputs1,
201 },
202 MirRelationExpr::Project {
203 input: input2,
204 outputs: outputs2,
205 },
206 ) => {
207 if outputs1 == outputs2 {
208 Self::compare_branches(&*input1, &*input2)
209 } else {
210 BranchCmp::Distinct
211 }
212 }
213 _ => {
214 if relation == other {
215 BranchCmp::Equivalent
216 } else {
217 BranchCmp::Distinct
218 }
219 }
220 }
221 }
222}