mz_transform/fusion/top_k.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Fuses a sequence of `TopK` operators into one `TopK` operator.
11
12use mz_expr::MirRelationExpr;
13use mz_repr::ReprScalarType;
14
15use crate::TransformCtx;
16
/// Optimizer transform that collapses directly nested `TopK` operators into a
/// single `TopK` operator, provided they share the same grouping key and the
/// same ordering key.
#[derive(Debug)]
pub struct TopK;
21
22impl crate::Transform for TopK {
23 fn name(&self) -> &'static str {
24 "TopKFusion"
25 }
26
27 #[mz_ore::instrument(
28 target = "optimizer",
29 level = "debug",
30 fields(path.segment = "topk_fusion")
31 )]
32 fn actually_perform_transform(
33 &self,
34 relation: &mut MirRelationExpr,
35 _: &mut TransformCtx,
36 ) -> Result<(), crate::TransformError> {
37 relation.visit_pre_mut(&mut Self::action);
38 mz_repr::explain::trace_plan(&*relation);
39 Ok(())
40 }
41}
42
43impl TopK {
44 /// Fuses a sequence of `TopK` operators in to one `TopK` operator.
45 pub fn action(relation: &mut MirRelationExpr) {
46 if let MirRelationExpr::TopK {
47 input,
48 group_key,
49 order_key,
50 limit,
51 offset,
52 monotonic,
53 expected_group_size,
54 } = relation
55 {
56 while let MirRelationExpr::TopK {
57 input: inner_input,
58 group_key: inner_group_key,
59 order_key: inner_order_key,
60 limit: inner_limit,
61 offset: inner_offset,
62 monotonic: inner_monotonic,
63 expected_group_size: inner_expected_group_size,
64 } = &mut **input
65 {
66 // We can fuse two chained TopK operators as long as they share the
67 // same grouping and ordering key.
68 if *group_key == *inner_group_key && *order_key == *inner_order_key {
69 // Given the following limit/offset pairs:
70 //
71 // inner_offset inner_limit
72 // |------------|xxxxxxxxxxxxxxxxxx|
73 // |------------|xxxxxxxxxxxx|
74 // outer_offset outer_limit
75 //
76 // the limit/offset pair of the fused TopK operator is computed
77 // as:
78 //
79 // offset = inner_offset + outer_offset
80 // limit = min(max(inner_limit - outer_offset, 0), outer_limit)
81 let inner_limit_int64 = inner_limit.as_ref().map(|l| l.as_literal_int64());
82 let outer_limit_int64 = limit.as_ref().map(|l| l.as_literal_int64());
83 // If either limit is an expression rather than a literal, bail out.
84 if inner_limit_int64 == Some(None) || outer_limit_int64 == Some(None) {
85 break;
86 }
87 let inner_limit_int64 = inner_limit_int64.flatten();
88 let outer_limit_int64 = outer_limit_int64.flatten();
89 // If either limit is less than zero, bail out.
90 if inner_limit_int64.map_or(false, |l| l < 0) {
91 break;
92 }
93 if outer_limit_int64.map_or(false, |l| l < 0) {
94 break;
95 }
96
97 let Ok(offset_int64) = i64::try_from(*offset) else {
98 break;
99 };
100
101 if let Some(inner_limit) = inner_limit_int64 {
102 let inner_limit_minus_outer_offset =
103 std::cmp::max(inner_limit - offset_int64, 0);
104 let new_limit = if let Some(outer_limit) = outer_limit_int64 {
105 std::cmp::min(outer_limit, inner_limit_minus_outer_offset)
106 } else {
107 inner_limit_minus_outer_offset
108 };
109 *limit = Some(mz_expr::MirScalarExpr::literal_ok(
110 mz_repr::Datum::Int64(new_limit),
111 ReprScalarType::Int64,
112 ));
113 }
114
115 if let Some(0) = limit.as_ref().and_then(|l| l.as_literal_int64()) {
116 relation.take_safely(None);
117 break;
118 }
119
120 *offset += *inner_offset;
121 *monotonic = *inner_monotonic;
122
123 // Expected group size is only a hint, and setting it small when the group size
124 // might actually be large would be bad.
125 //
126 // rust-lang/rust#70086 would allow a.zip_with(b, max) here.
127 *inner_expected_group_size =
128 match (&expected_group_size, &inner_expected_group_size) {
129 (Some(a), Some(b)) => Some(std::cmp::max(*a, *b)),
130 _ => None,
131 };
132
133 **input = inner_input.take_dangerous();
134 } else {
135 break;
136 }
137 }
138 }
139 }
140}