1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
//! Fuses a sequence of `TopK` operators into one `TopK` operator.
use mz_expr::MirRelationExpr;
use crate::TransformCtx;
/// Fuses a sequence of `TopK` operators into one `TopK` operator if
/// they happen to share the same grouping and ordering key.
///
/// The fusion itself is performed by [`TopK::action`]; this type is the
/// optimizer transform that applies it (reported as `"TopKFusion"`).
#[derive(Debug)]
pub struct TopK;
impl crate::Transform for TopK {
    /// Returns the display name of this transform, `"TopKFusion"`.
    fn name(&self) -> &'static str {
        "TopKFusion"
    }

    /// Applies [`TopK::action`] to every node of `relation` through
    /// `visit_pre_mut`, then hands the rewritten plan to `trace_plan`.
    ///
    /// This transform never fails; it always returns `Ok(())`.
    #[mz_ore::instrument(
        target = "optimizer",
        level = "debug",
        fields(path.segment = "topk_fusion")
    )]
    fn actually_perform_transform(
        &self,
        relation: &mut MirRelationExpr,
        _: &mut TransformCtx,
    ) -> Result<(), crate::TransformError> {
        // The explicit `&mut MirRelationExpr` annotation keeps the closure
        // generic over the borrow's lifetime, as the visitor requires.
        let mut fuse_node = |node: &mut MirRelationExpr| Self::action(node);
        relation.visit_pre_mut(&mut fuse_node);
        mz_repr::explain::trace_plan(&*relation);
        Ok(())
    }
}
impl TopK {
    /// Fuses a chain of directly nested `TopK` operators into one `TopK` operator.
    ///
    /// While `relation` is a `TopK` whose input is also a `TopK` with the same
    /// `group_key` and `order_key`, the two operators are merged:
    ///
    /// * the offsets are added together,
    /// * the limit is recomputed as `min(max(inner_limit - outer_offset, 0), outer_limit)`
    ///   (a missing limit on either side is treated as "unbounded"),
    /// * `monotonic` is taken from the inner operator,
    /// * `expected_group_size` becomes the larger of the two hints, or `None` if
    ///   either side has no hint.
    ///
    /// Fusion stops, leaving the current pair untouched, when the keys differ, when
    /// either limit is a non-literal expression, when either literal limit is
    /// negative, when the outer offset does not fit in an `i64`, or when the fused
    /// offset would overflow `usize`.
    ///
    /// If the fused limit is the literal `0`, `relation` is replaced via
    /// `take_safely`, since a `TopK` with a zero limit emits no rows.
    pub fn action(relation: &mut MirRelationExpr) {
        if let MirRelationExpr::TopK {
            input,
            group_key,
            order_key,
            limit,
            offset,
            monotonic,
            expected_group_size,
        } = relation
        {
            while let MirRelationExpr::TopK {
                input: inner_input,
                group_key: inner_group_key,
                order_key: inner_order_key,
                limit: inner_limit,
                offset: inner_offset,
                monotonic: inner_monotonic,
                expected_group_size: inner_expected_group_size,
            } = &mut **input
            {
                // We can fuse two chained TopK operators only if they share the
                // same grouping and ordering key.
                if *group_key != *inner_group_key || *order_key != *inner_order_key {
                    break;
                }

                // Given the following limit/offset pairs:
                //
                // inner_offset          inner_limit
                // |------------|xxxxxxxxxxxxxxxxxx|
                //              |------------|xxxxxxxxxxxx|
                //              outer_offset    outer_limit
                //
                // the limit/offset pair of the fused TopK operator is computed as:
                //
                // offset = inner_offset + outer_offset
                // limit = min(max(inner_limit - outer_offset, 0), outer_limit)
                let inner_limit_int64 = inner_limit.as_ref().map(|l| l.as_literal_int64());
                let outer_limit_int64 = limit.as_ref().map(|l| l.as_literal_int64());
                // If either limit is an expression rather than a literal, bail out.
                if inner_limit_int64 == Some(None) || outer_limit_int64 == Some(None) {
                    break;
                }
                let inner_limit_int64 = inner_limit_int64.flatten();
                let outer_limit_int64 = outer_limit_int64.flatten();
                // If either limit is less than zero, bail out.
                if inner_limit_int64.map_or(false, |l| l < 0)
                    || outer_limit_int64.map_or(false, |l| l < 0)
                {
                    break;
                }
                let Ok(offset_int64) = i64::try_from(*offset) else {
                    break;
                };
                // Compute the fused offset before mutating anything, so that bailing
                // out on overflow leaves both operators exactly as they were.
                let Some(fused_offset) = offset.checked_add(*inner_offset) else {
                    break;
                };

                if let Some(inner_limit) = inner_limit_int64 {
                    // Both operands are non-negative here, so this cannot overflow.
                    let inner_limit_minus_outer_offset =
                        std::cmp::max(inner_limit - offset_int64, 0);
                    let new_limit = if let Some(outer_limit) = outer_limit_int64 {
                        std::cmp::min(outer_limit, inner_limit_minus_outer_offset)
                    } else {
                        inner_limit_minus_outer_offset
                    };
                    *limit = Some(mz_expr::MirScalarExpr::literal_ok(
                        mz_repr::Datum::Int64(new_limit),
                        mz_repr::ScalarType::Int64,
                    ));
                }

                // A fused limit of zero means the operator produces no rows at all.
                if let Some(0) = limit.as_ref().and_then(|l| l.as_literal_int64()) {
                    relation.take_safely();
                    break;
                }

                *offset = fused_offset;
                *monotonic = *inner_monotonic;
                // Expected group size is only a hint, and setting it small when the group
                // size might actually be large would be bad: keep the larger of the two
                // hints, and drop the hint if either side lacks one. The result must be
                // stored on the *outer* operator, because the inner one is discarded below.
                *expected_group_size = match (*expected_group_size, *inner_expected_group_size)
                {
                    (Some(a), Some(b)) => Some(std::cmp::max(a, b)),
                    _ => None,
                };

                // Splice the inner operator's input directly under the outer operator.
                **input = inner_input.take_dangerous();
            }
        }
    }
}