Skip to main content

mz_transform/fusion/
top_k.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Fuses a sequence of `TopK` operators into one `TopK` operator.
11
12use mz_expr::MirRelationExpr;
13use mz_repr::ReprScalarType;
14
15use crate::TransformCtx;
16
/// Transform that merges a chain of directly nested `TopK` operators into a
/// single `TopK` operator, provided the nested operators agree on both their
/// grouping key and their ordering key.
#[derive(Debug)]
pub struct TopK;
21
22impl crate::Transform for TopK {
23    fn name(&self) -> &'static str {
24        "TopKFusion"
25    }
26
27    #[mz_ore::instrument(
28        target = "optimizer",
29        level = "debug",
30        fields(path.segment = "topk_fusion")
31    )]
32    fn actually_perform_transform(
33        &self,
34        relation: &mut MirRelationExpr,
35        _: &mut TransformCtx,
36    ) -> Result<(), crate::TransformError> {
37        relation.visit_pre_mut(&mut Self::action);
38        mz_repr::explain::trace_plan(&*relation);
39        Ok(())
40    }
41}
42
43impl TopK {
44    /// Fuses a sequence of `TopK` operators in to one `TopK` operator.
45    pub fn action(relation: &mut MirRelationExpr) {
46        if let MirRelationExpr::TopK {
47            input,
48            group_key,
49            order_key,
50            limit,
51            offset,
52            monotonic,
53            expected_group_size,
54        } = relation
55        {
56            while let MirRelationExpr::TopK {
57                input: inner_input,
58                group_key: inner_group_key,
59                order_key: inner_order_key,
60                limit: inner_limit,
61                offset: inner_offset,
62                monotonic: inner_monotonic,
63                expected_group_size: inner_expected_group_size,
64            } = &mut **input
65            {
66                // We can fuse two chained TopK operators as long as they share the
67                // same grouping and ordering key.
68                if *group_key == *inner_group_key && *order_key == *inner_order_key {
69                    // Given the following limit/offset pairs:
70                    //
71                    // inner_offset          inner_limit
72                    // |------------|xxxxxxxxxxxxxxxxxx|
73                    //              |------------|xxxxxxxxxxxx|
74                    //              outer_offset    outer_limit
75                    //
76                    // the limit/offset pair of the fused TopK operator is computed
77                    // as:
78                    //
79                    // offset = inner_offset + outer_offset
80                    // limit = min(max(inner_limit - outer_offset, 0), outer_limit)
81                    let inner_limit_int64 = inner_limit.as_ref().map(|l| l.as_literal_int64());
82                    let outer_limit_int64 = limit.as_ref().map(|l| l.as_literal_int64());
83                    // If either limit is an expression rather than a literal, bail out.
84                    if inner_limit_int64 == Some(None) || outer_limit_int64 == Some(None) {
85                        break;
86                    }
87                    let inner_limit_int64 = inner_limit_int64.flatten();
88                    let outer_limit_int64 = outer_limit_int64.flatten();
89                    // If either limit is less than zero, bail out.
90                    if inner_limit_int64.map_or(false, |l| l < 0) {
91                        break;
92                    }
93                    if outer_limit_int64.map_or(false, |l| l < 0) {
94                        break;
95                    }
96
97                    let Ok(offset_int64) = i64::try_from(*offset) else {
98                        break;
99                    };
100
101                    if let Some(inner_limit) = inner_limit_int64 {
102                        let inner_limit_minus_outer_offset =
103                            std::cmp::max(inner_limit - offset_int64, 0);
104                        let new_limit = if let Some(outer_limit) = outer_limit_int64 {
105                            std::cmp::min(outer_limit, inner_limit_minus_outer_offset)
106                        } else {
107                            inner_limit_minus_outer_offset
108                        };
109                        *limit = Some(mz_expr::MirScalarExpr::literal_ok(
110                            mz_repr::Datum::Int64(new_limit),
111                            ReprScalarType::Int64,
112                        ));
113                    }
114
115                    if let Some(0) = limit.as_ref().and_then(|l| l.as_literal_int64()) {
116                        relation.take_safely(None);
117                        break;
118                    }
119
120                    *offset += *inner_offset;
121                    *monotonic = *inner_monotonic;
122
123                    // Expected group size is only a hint, and setting it small when the group size
124                    // might actually be large would be bad.
125                    //
126                    // rust-lang/rust#70086 would allow a.zip_with(b, max) here.
127                    *inner_expected_group_size =
128                        match (&expected_group_size, &inner_expected_group_size) {
129                            (Some(a), Some(b)) => Some(std::cmp::max(*a, *b)),
130                            _ => None,
131                        };
132
133                    **input = inner_input.take_dangerous();
134                } else {
135                    break;
136                }
137            }
138        }
139    }
140}