mz_transform/fusion/
top_k.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Fuses a sequence of `TopK` operators into one `TopK` operator.
11
12use mz_expr::MirRelationExpr;
13
14use crate::TransformCtx;
15
/// Transform that collapses directly nested `TopK` operators into a single
/// `TopK` operator whenever they use the same grouping and ordering key.
#[derive(Debug)]
pub struct TopK;
20
21impl crate::Transform for TopK {
22    fn name(&self) -> &'static str {
23        "TopKFusion"
24    }
25
26    #[mz_ore::instrument(
27        target = "optimizer",
28        level = "debug",
29        fields(path.segment = "topk_fusion")
30    )]
31    fn actually_perform_transform(
32        &self,
33        relation: &mut MirRelationExpr,
34        _: &mut TransformCtx,
35    ) -> Result<(), crate::TransformError> {
36        relation.visit_pre_mut(&mut Self::action);
37        mz_repr::explain::trace_plan(&*relation);
38        Ok(())
39    }
40}
41
42impl TopK {
43    /// Fuses a sequence of `TopK` operators in to one `TopK` operator.
44    pub fn action(relation: &mut MirRelationExpr) {
45        if let MirRelationExpr::TopK {
46            input,
47            group_key,
48            order_key,
49            limit,
50            offset,
51            monotonic,
52            expected_group_size,
53        } = relation
54        {
55            while let MirRelationExpr::TopK {
56                input: inner_input,
57                group_key: inner_group_key,
58                order_key: inner_order_key,
59                limit: inner_limit,
60                offset: inner_offset,
61                monotonic: inner_monotonic,
62                expected_group_size: inner_expected_group_size,
63            } = &mut **input
64            {
65                // We can fuse two chained TopK operators as long as they share the
66                // same grouping and ordering key.
67                if *group_key == *inner_group_key && *order_key == *inner_order_key {
68                    // Given the following limit/offset pairs:
69                    //
70                    // inner_offset          inner_limit
71                    // |------------|xxxxxxxxxxxxxxxxxx|
72                    //              |------------|xxxxxxxxxxxx|
73                    //              outer_offset    outer_limit
74                    //
75                    // the limit/offset pair of the fused TopK operator is computed
76                    // as:
77                    //
78                    // offset = inner_offset + outer_offset
79                    // limit = min(max(inner_limit - outer_offset, 0), outer_limit)
80                    let inner_limit_int64 = inner_limit.as_ref().map(|l| l.as_literal_int64());
81                    let outer_limit_int64 = limit.as_ref().map(|l| l.as_literal_int64());
82                    // If either limit is an expression rather than a literal, bail out.
83                    if inner_limit_int64 == Some(None) || outer_limit_int64 == Some(None) {
84                        break;
85                    }
86                    let inner_limit_int64 = inner_limit_int64.flatten();
87                    let outer_limit_int64 = outer_limit_int64.flatten();
88                    // If either limit is less than zero, bail out.
89                    if inner_limit_int64.map_or(false, |l| l < 0) {
90                        break;
91                    }
92                    if outer_limit_int64.map_or(false, |l| l < 0) {
93                        break;
94                    }
95
96                    let Ok(offset_int64) = i64::try_from(*offset) else {
97                        break;
98                    };
99
100                    if let Some(inner_limit) = inner_limit_int64 {
101                        let inner_limit_minus_outer_offset =
102                            std::cmp::max(inner_limit - offset_int64, 0);
103                        let new_limit = if let Some(outer_limit) = outer_limit_int64 {
104                            std::cmp::min(outer_limit, inner_limit_minus_outer_offset)
105                        } else {
106                            inner_limit_minus_outer_offset
107                        };
108                        *limit = Some(mz_expr::MirScalarExpr::literal_ok(
109                            mz_repr::Datum::Int64(new_limit),
110                            mz_repr::ScalarType::Int64,
111                        ));
112                    }
113
114                    if let Some(0) = limit.as_ref().and_then(|l| l.as_literal_int64()) {
115                        relation.take_safely(None);
116                        break;
117                    }
118
119                    *offset += *inner_offset;
120                    *monotonic = *inner_monotonic;
121
122                    // Expected group size is only a hint, and setting it small when the group size
123                    // might actually be large would be bad.
124                    //
125                    // rust-lang/rust#70086 would allow a.zip_with(b, max) here.
126                    *inner_expected_group_size =
127                        match (&expected_group_size, &inner_expected_group_size) {
128                            (Some(a), Some(b)) => Some(std::cmp::max(*a, *b)),
129                            _ => None,
130                        };
131
132                    **input = inner_input.take_dangerous();
133                } else {
134                    break;
135                }
136            }
137        }
138    }
139}