mz_transform/fusion/top_k.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Fuses a sequence of `TopK` operators into one `TopK` operator.
11
12use mz_expr::MirRelationExpr;
13
14use crate::TransformCtx;
15
/// Fuses a sequence of `TopK` operators into one `TopK` operator when
/// adjacent operators share the same grouping and ordering key.
///
/// The fusion itself is performed by [`TopK::action`], which additionally
/// requires every limit involved to be absent or a non-negative `Int64`
/// literal.
#[derive(Debug)]
pub struct TopK;
20
21impl crate::Transform for TopK {
22 fn name(&self) -> &'static str {
23 "TopKFusion"
24 }
25
26 #[mz_ore::instrument(
27 target = "optimizer",
28 level = "debug",
29 fields(path.segment = "topk_fusion")
30 )]
31 fn actually_perform_transform(
32 &self,
33 relation: &mut MirRelationExpr,
34 _: &mut TransformCtx,
35 ) -> Result<(), crate::TransformError> {
36 relation.visit_pre_mut(&mut Self::action);
37 mz_repr::explain::trace_plan(&*relation);
38 Ok(())
39 }
40}
41
42impl TopK {
43 /// Fuses a sequence of `TopK` operators in to one `TopK` operator.
44 pub fn action(relation: &mut MirRelationExpr) {
45 if let MirRelationExpr::TopK {
46 input,
47 group_key,
48 order_key,
49 limit,
50 offset,
51 monotonic,
52 expected_group_size,
53 } = relation
54 {
55 while let MirRelationExpr::TopK {
56 input: inner_input,
57 group_key: inner_group_key,
58 order_key: inner_order_key,
59 limit: inner_limit,
60 offset: inner_offset,
61 monotonic: inner_monotonic,
62 expected_group_size: inner_expected_group_size,
63 } = &mut **input
64 {
65 // We can fuse two chained TopK operators as long as they share the
66 // same grouping and ordering key.
67 if *group_key == *inner_group_key && *order_key == *inner_order_key {
68 // Given the following limit/offset pairs:
69 //
70 // inner_offset inner_limit
71 // |------------|xxxxxxxxxxxxxxxxxx|
72 // |------------|xxxxxxxxxxxx|
73 // outer_offset outer_limit
74 //
75 // the limit/offset pair of the fused TopK operator is computed
76 // as:
77 //
78 // offset = inner_offset + outer_offset
79 // limit = min(max(inner_limit - outer_offset, 0), outer_limit)
80 let inner_limit_int64 = inner_limit.as_ref().map(|l| l.as_literal_int64());
81 let outer_limit_int64 = limit.as_ref().map(|l| l.as_literal_int64());
82 // If either limit is an expression rather than a literal, bail out.
83 if inner_limit_int64 == Some(None) || outer_limit_int64 == Some(None) {
84 break;
85 }
86 let inner_limit_int64 = inner_limit_int64.flatten();
87 let outer_limit_int64 = outer_limit_int64.flatten();
88 // If either limit is less than zero, bail out.
89 if inner_limit_int64.map_or(false, |l| l < 0) {
90 break;
91 }
92 if outer_limit_int64.map_or(false, |l| l < 0) {
93 break;
94 }
95
96 let Ok(offset_int64) = i64::try_from(*offset) else {
97 break;
98 };
99
100 if let Some(inner_limit) = inner_limit_int64 {
101 let inner_limit_minus_outer_offset =
102 std::cmp::max(inner_limit - offset_int64, 0);
103 let new_limit = if let Some(outer_limit) = outer_limit_int64 {
104 std::cmp::min(outer_limit, inner_limit_minus_outer_offset)
105 } else {
106 inner_limit_minus_outer_offset
107 };
108 *limit = Some(mz_expr::MirScalarExpr::literal_ok(
109 mz_repr::Datum::Int64(new_limit),
110 mz_repr::ScalarType::Int64,
111 ));
112 }
113
114 if let Some(0) = limit.as_ref().and_then(|l| l.as_literal_int64()) {
115 relation.take_safely(None);
116 break;
117 }
118
119 *offset += *inner_offset;
120 *monotonic = *inner_monotonic;
121
122 // Expected group size is only a hint, and setting it small when the group size
123 // might actually be large would be bad.
124 //
125 // rust-lang/rust#70086 would allow a.zip_with(b, max) here.
126 *inner_expected_group_size =
127 match (&expected_group_size, &inner_expected_group_size) {
128 (Some(a), Some(b)) => Some(std::cmp::max(*a, *b)),
129 _ => None,
130 };
131
132 **input = inner_input.take_dangerous();
133 } else {
134 break;
135 }
136 }
137 }
138 }
139}