Skip to main content

mz_expr/relation/
func.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![allow(missing_docs)]
11
12use std::cmp::{max, min};
13use std::iter::Sum;
14use std::ops::Deref;
15use std::str::FromStr;
16use std::{fmt, iter};
17
18use chrono::{DateTime, NaiveDateTime, NaiveTime, Utc};
19use dec::OrderedDecimal;
20use itertools::{Either, Itertools};
21use mz_lowertest::MzReflect;
22use mz_ore::cast::CastFrom;
23
24use mz_ore::str::separated;
25use mz_ore::{soft_assert_eq_no_log, soft_assert_or_log};
26use mz_repr::adt::array::ArrayDimension;
27use mz_repr::adt::date::Date;
28use mz_repr::adt::interval::Interval;
29use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
30use mz_repr::adt::regex::{Regex as ReprRegex, RegexCompilationError};
31use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike};
32use mz_repr::{
33    ColumnName, Datum, Diff, ReprColumnType, ReprRelationType, Row, RowArena, RowPacker, SharedRow,
34    SqlColumnType, SqlRelationType, SqlScalarType, datum_size,
35};
36use num::{CheckedAdd, Integer, Signed, ToPrimitive};
37use ordered_float::OrderedFloat;
38use regex::Regex;
39use serde::{Deserialize, Serialize};
40use smallvec::SmallVec;
41
42use crate::EvalError;
43use crate::WindowFrameBound::{
44    CurrentRow, OffsetFollowing, OffsetPreceding, UnboundedFollowing, UnboundedPreceding,
45};
46use crate::WindowFrameUnits::{Groups, Range, Rows};
47use crate::explain::{HumanizedExpr, HumanizerMode};
48use crate::relation::{
49    ColumnOrder, WindowFrame, WindowFrameBound, WindowFrameUnits, compare_columns,
50};
51use crate::scalar::func::{add_timestamp_months, jsonb_stringify};
52
53// TODO(jamii) be careful about overflow in sum/avg
54// see https://timely.zulipchat.com/#narrow/stream/186635-engineering/topic/additional.20work/near/163507435
55
56fn max_string<'a, I>(datums: I) -> Datum<'a>
57where
58    I: IntoIterator<Item = Datum<'a>>,
59{
60    match datums
61        .into_iter()
62        .filter(|d| !d.is_null())
63        .max_by(|a, b| a.unwrap_str().cmp(b.unwrap_str()))
64    {
65        Some(datum) => datum,
66        None => Datum::Null,
67    }
68}
69
70fn max_datum<'a, I, DatumType>(datums: I) -> Datum<'a>
71where
72    I: IntoIterator<Item = Datum<'a>>,
73    DatumType: TryFrom<Datum<'a>> + Ord,
74    <DatumType as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
75    Datum<'a>: From<Option<DatumType>>,
76{
77    let x: Option<DatumType> = datums
78        .into_iter()
79        .filter(|d| !d.is_null())
80        .map(|d| DatumType::try_from(d).expect("unexpected type"))
81        .max();
82
83    x.into()
84}
85
86fn min_datum<'a, I, DatumType>(datums: I) -> Datum<'a>
87where
88    I: IntoIterator<Item = Datum<'a>>,
89    DatumType: TryFrom<Datum<'a>> + Ord,
90    <DatumType as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
91    Datum<'a>: From<Option<DatumType>>,
92{
93    let x: Option<DatumType> = datums
94        .into_iter()
95        .filter(|d| !d.is_null())
96        .map(|d| DatumType::try_from(d).expect("unexpected type"))
97        .min();
98
99    x.into()
100}
101
102fn min_string<'a, I>(datums: I) -> Datum<'a>
103where
104    I: IntoIterator<Item = Datum<'a>>,
105{
106    match datums
107        .into_iter()
108        .filter(|d| !d.is_null())
109        .min_by(|a, b| a.unwrap_str().cmp(b.unwrap_str()))
110    {
111        Some(datum) => datum,
112        None => Datum::Null,
113    }
114}
115
116fn sum_datum<'a, I, DatumType, ResultType>(datums: I) -> Datum<'a>
117where
118    I: IntoIterator<Item = Datum<'a>>,
119    DatumType: TryFrom<Datum<'a>>,
120    <DatumType as TryFrom<Datum<'a>>>::Error: std::fmt::Debug,
121    ResultType: From<DatumType> + Sum + Into<Datum<'a>>,
122{
123    let mut datums = datums.into_iter().filter(|d| !d.is_null()).peekable();
124    if datums.peek().is_none() {
125        Datum::Null
126    } else {
127        let x = datums
128            .map(|d| ResultType::from(DatumType::try_from(d).expect("unexpected type")))
129            .sum::<ResultType>();
130        x.into()
131    }
132}
133
134/// Count-aware signed-integer sum. Accumulates `Σ value·diff` in `i128`, which
135/// matches the width of the dataflow's `Accum::SimpleNumber` accumulator (see
136/// `build_accumulable` and `finalize_accum` in `mz_compute::render::reduce`);
137/// `narrow` then reproduces that variant's `finalize_accum` arm. Unlike
138/// `expand_counts`, this consumes the multiplicity directly, so it is linear in
139/// the number of distinct values and correct for negative diffs (retractions),
140/// which `expand_counts` would silently drop.
141///
142/// Returns `Datum::Null` when no non-null value was accumulated, matching
143/// `finalize_accum`'s null handling: its `is_zero` check on `SimpleNumber`
144/// requires both a zero running sum and a zero non-null count.
145fn sum_signed_int_counted<'a, I, N>(datums: I, narrow: N) -> Datum<'a>
146where
147    I: IntoIterator<Item = (Datum<'a>, Diff)>,
148    N: FnOnce(i128) -> Datum<'a>,
149{
150    let mut accum: i128 = 0;
151    let mut non_nulls = Diff::ZERO;
152    for (datum, diff) in datums {
153        if datum.is_null() {
154            continue;
155        }
156        let value = match datum {
157            Datum::Int16(i) => i128::from(i),
158            Datum::Int32(i) => i128::from(i),
159            Datum::Int64(i) => i128::from(i),
160            other => panic!("unexpected non-integer datum in signed sum: {other:?}"),
161        };
162        // The dataflow accumulates `value * diff` in an `Overflowing<i128>`; we
163        // mirror that. Genuine i128 overflow would require summands far beyond
164        // any realistic input, so wrapping matches the dataflow's production
165        // behavior.
166        accum = accum.wrapping_add(value.wrapping_mul(i128::from(diff.into_inner())));
167        non_nulls += diff;
168    }
169    if accum == 0 && non_nulls.is_zero() {
170        Datum::Null
171    } else {
172        narrow(accum)
173    }
174}
175
176fn sum_numeric<'a, I>(datums: I) -> Datum<'a>
177where
178    I: IntoIterator<Item = Datum<'a>>,
179{
180    let mut cx = numeric::cx_datum();
181    let mut sum = Numeric::zero();
182    let mut empty = true;
183    for d in datums {
184        if !d.is_null() {
185            empty = false;
186            cx.add(&mut sum, &d.unwrap_numeric().0);
187        }
188    }
189    match empty {
190        true => Datum::Null,
191        false => Datum::from(sum),
192    }
193}
194
195fn count<'a, I>(datums: I) -> Datum<'a>
196where
197    I: IntoIterator<Item = (Datum<'a>, Diff)>,
198{
199    // Count is accumulable: rather than expand each `(datum, diff)` into `diff`
200    // copies and count them, we sum the diffs directly. A net-negative count is
201    // possible (the surface does not define behavior in that case) and surfaces
202    // here as a negative result.
203    // TODO(jkosh44) This should error when the count can't fit inside of an `i64` instead of returning a negative result.
204    let mut count = Diff::ZERO;
205    for (datum, diff) in datums {
206        if !datum.is_null() {
207            count += diff;
208        }
209    }
210    Datum::from(count.into_inner())
211}
212
213fn any<'a, I>(datums: I) -> Datum<'a>
214where
215    I: IntoIterator<Item = Datum<'a>>,
216{
217    datums
218        .into_iter()
219        .fold(Datum::False, |state, next| match (state, next) {
220            (Datum::True, _) | (_, Datum::True) => Datum::True,
221            (Datum::Null, _) | (_, Datum::Null) => Datum::Null,
222            _ => Datum::False,
223        })
224}
225
226fn all<'a, I>(datums: I) -> Datum<'a>
227where
228    I: IntoIterator<Item = Datum<'a>>,
229{
230    datums
231        .into_iter()
232        .fold(Datum::True, |state, next| match (state, next) {
233            (Datum::False, _) | (_, Datum::False) => Datum::False,
234            (Datum::Null, _) | (_, Datum::Null) => Datum::Null,
235            _ => Datum::True,
236        })
237}
238
239fn string_agg<'a, I>(datums: I, temp_storage: &'a RowArena, order_by: &[ColumnOrder]) -> Datum<'a>
240where
241    I: IntoIterator<Item = Datum<'a>>,
242{
243    const EMPTY_SEP: &str = "";
244
245    let datums = order_aggregate_datums(datums, order_by);
246    let mut sep_value_pairs = datums.into_iter().filter_map(|d| {
247        if d.is_null() {
248            return None;
249        }
250        let mut value_sep = d.unwrap_list().iter();
251        match (value_sep.next().unwrap(), value_sep.next().unwrap()) {
252            (Datum::Null, _) => None,
253            (Datum::String(val), Datum::Null) => Some((EMPTY_SEP, val)),
254            (Datum::String(val), Datum::String(sep)) => Some((sep, val)),
255            _ => unreachable!(),
256        }
257    });
258
259    let mut s = String::default();
260    match sep_value_pairs.next() {
261        // First value not prefixed by its separator
262        Some((_, value)) => s.push_str(value),
263        // If no non-null values sent, return NULL.
264        None => return Datum::Null,
265    }
266
267    for (sep, value) in sep_value_pairs {
268        s.push_str(sep);
269        s.push_str(value);
270    }
271
272    Datum::String(temp_storage.push_string(s))
273}
274
275fn jsonb_agg<'a, I>(datums: I, temp_storage: &'a RowArena, order_by: &[ColumnOrder]) -> Datum<'a>
276where
277    I: IntoIterator<Item = Datum<'a>>,
278{
279    let datums = order_aggregate_datums(datums, order_by);
280    temp_storage.make_datum(|packer| {
281        packer.push_list(datums.into_iter().filter(|d| !d.is_null()));
282    })
283}
284
285fn dict_agg<'a, I>(datums: I, temp_storage: &'a RowArena, order_by: &[ColumnOrder]) -> Datum<'a>
286where
287    I: IntoIterator<Item = Datum<'a>>,
288{
289    let datums = order_aggregate_datums(datums, order_by);
290    temp_storage.make_datum(|packer| {
291        let mut datums: Vec<_> = datums
292            .into_iter()
293            .filter_map(|d| {
294                if d.is_null() {
295                    return None;
296                }
297                let mut list = d.unwrap_list().iter();
298                let key = list.next().unwrap();
299                let val = list.next().unwrap();
300                if key.is_null() {
301                    // TODO(benesch): this should produce an error, but
302                    // aggregate functions cannot presently produce errors.
303                    None
304                } else {
305                    Some((key.unwrap_str(), val))
306                }
307            })
308            .collect();
309        // datums are ordered by any ORDER BY clause now, and we want to preserve
310        // the last entry for each key, but we also need to present unique and sorted
311        // keys to push_dict. Use sort_by here, which is stable, and so will preserve
312        // the ORDER BY order. Then reverse and dedup to retain the last of each
313        // key. Reverse again so we're back in push_dict order.
314        datums.sort_by_key(|(k, _v)| *k);
315        datums.reverse();
316        datums.dedup_by_key(|(k, _v)| *k);
317        datums.reverse();
318        packer.push_dict(datums);
319    })
320}
321
322/// Assuming datums is a List, sort them by the 2nd through Nth elements
323/// corresponding to order_by, then return the 1st element.
324///
325/// Near the usages of this function, we sometimes want to produce Datums with a shorter lifetime
326/// than 'a. We have to actually perform the shortening of the lifetime here, inside this function,
327/// because if we were to simply return `impl Iterator<Item = Datum<'a>>`, that wouldn't be
328/// covariant in the item type, because opaque types are always invariant. (Contrast this with how
329/// we perform the shortening _inside_ this function: the input of the `map` is known to
330/// specifically be `std::vec::IntoIter`, which is known to be covariant.)
331pub fn order_aggregate_datums<'a: 'b, 'b, I>(
332    datums: I,
333    order_by: &[ColumnOrder],
334) -> impl Iterator<Item = Datum<'b>>
335where
336    I: IntoIterator<Item = Datum<'a>>,
337{
338    order_aggregate_datums_with_rank_inner(datums, order_by)
339        .into_iter()
340        // (`payload` is coerced here to `Datum<'b>` in the argument of the closure)
341        .map(|(payload, _order_datums)| payload)
342}
343
344/// Assuming datums is a List, sort them by the 2nd through Nth elements
345/// corresponding to order_by, then return the 1st element and computed order by expression.
346fn order_aggregate_datums_with_rank<'a, I>(
347    datums: I,
348    order_by: &[ColumnOrder],
349) -> impl Iterator<Item = (Datum<'a>, Row)>
350where
351    I: IntoIterator<Item = Datum<'a>>,
352{
353    order_aggregate_datums_with_rank_inner(datums, order_by)
354        .into_iter()
355        .map(|(payload, order_by_datums)| (payload, Row::pack(order_by_datums)))
356}
357
358fn order_aggregate_datums_with_rank_inner<'a, I>(
359    datums: I,
360    order_by: &[ColumnOrder],
361) -> Vec<(Datum<'a>, Vec<Datum<'a>>)>
362where
363    I: IntoIterator<Item = Datum<'a>>,
364{
365    let mut decoded: Vec<(Datum, Vec<Datum>)> = datums
366        .into_iter()
367        .map(|d| {
368            let list = d.unwrap_list();
369            let mut list_it = list.iter();
370            let payload = list_it.next().unwrap();
371
372            // We decode the order_by Datums here instead of the comparison function, because the
373            // comparison function is expected to be called `O(log n)` times on each input row.
374            // The only downside is that the decoded data might be bigger, but I think that's fine,
375            // because:
376            // - if we have a window partition so big that this would create a memory problem, then
377            //   the non-incrementalness of window functions will create a serious CPU problem
378            //   anyway,
379            // - and anyhow various other parts of the window function code already do decoding
380            //   upfront.
381            let mut order_by_datums = Vec::with_capacity(order_by.len());
382            for _ in 0..order_by.len() {
383                order_by_datums.push(
384                    list_it
385                        .next()
386                        .expect("must have exactly the same number of Datums as `order_by`"),
387                );
388            }
389
390            (payload, order_by_datums)
391        })
392        .collect();
393
394    let mut sort_by =
395        |(payload_left, left_order_by_datums): &(Datum, Vec<Datum>),
396         (payload_right, right_order_by_datums): &(Datum, Vec<Datum>)| {
397            compare_columns(
398                order_by,
399                left_order_by_datums,
400                right_order_by_datums,
401                || payload_left.cmp(payload_right),
402            )
403        };
404    // `sort_unstable_by` can be faster and uses less memory than `sort_by`. An unstable sort is
405    // enough here, because if two elements are equal in our `compare` function, then the elements
406    // are actually binary-equal (because of the `tiebreaker` given to `compare_columns`), so it
407    // doesn't matter what order they end up in.
408    decoded.sort_unstable_by(&mut sort_by);
409    decoded
410}
411
412fn array_concat<'a, I>(datums: I, temp_storage: &'a RowArena, order_by: &[ColumnOrder]) -> Datum<'a>
413where
414    I: IntoIterator<Item = Datum<'a>>,
415{
416    let datums = order_aggregate_datums(datums, order_by);
417    let datums: Vec<_> = datums
418        .into_iter()
419        .map(|d| d.unwrap_array().elements().iter())
420        .flatten()
421        .collect();
422    let dims = ArrayDimension {
423        lower_bound: 1,
424        length: datums.len(),
425    };
426    temp_storage.make_datum(|packer| {
427        packer.try_push_array(&[dims], datums).unwrap();
428    })
429}
430
431fn list_concat<'a, I>(datums: I, temp_storage: &'a RowArena, order_by: &[ColumnOrder]) -> Datum<'a>
432where
433    I: IntoIterator<Item = Datum<'a>>,
434{
435    let datums = order_aggregate_datums(datums, order_by);
436    temp_storage.make_datum(|packer| {
437        packer.push_list(datums.into_iter().map(|d| d.unwrap_list().iter()).flatten());
438    })
439}
440
441/// The expected input is in the format of `[((OriginalRow, [EncodedArgs]), OrderByExprs...)]`
442/// The output is in the format of `[result_value, original_row]`.
443/// See an example at `lag_lead`, where the input-output formats are similar.
444fn row_number<'a, I>(
445    datums: I,
446    callers_temp_storage: &'a RowArena,
447    order_by: &[ColumnOrder],
448) -> Datum<'a>
449where
450    I: IntoIterator<Item = Datum<'a>>,
451{
452    // We want to use our own temp_storage here, to avoid flooding `callers_temp_storage` with a
453    // large number of new datums. This is because we don't want to make an assumption about
454    // whether the caller creates a new temp_storage between window partitions.
455    let temp_storage = RowArena::new();
456    let datums = row_number_no_list(datums, &temp_storage, order_by);
457
458    callers_temp_storage.make_datum(|packer| {
459        packer.push_list(datums);
460    })
461}
462
463/// Like `row_number`, but doesn't perform the final wrapping in a list, returning an Iterator
464/// instead.
465fn row_number_no_list<'a: 'b, 'b, I>(
466    datums: I,
467    callers_temp_storage: &'b RowArena,
468    order_by: &[ColumnOrder],
469) -> impl Iterator<Item = Datum<'b>>
470where
471    I: IntoIterator<Item = Datum<'a>>,
472{
473    let datums = order_aggregate_datums(datums, order_by);
474
475    callers_temp_storage.reserve(datums.size_hint().0);
476    #[allow(clippy::disallowed_methods)]
477    datums
478        .into_iter()
479        .map(|d| d.unwrap_list().iter())
480        .flatten()
481        .zip(1i64..)
482        .map(|(d, i)| {
483            callers_temp_storage.make_datum(|packer| {
484                packer.push_list_with(|packer| {
485                    packer.push(Datum::Int64(i));
486                    packer.push(d);
487                });
488            })
489        })
490}
491
492/// The expected input is in the format of `[((OriginalRow, [EncodedArgs]), OrderByExprs...)]`
493/// The output is in the format of `[result_value, original_row]`.
494/// See an example at `lag_lead`, where the input-output formats are similar.
495fn rank<'a, I>(datums: I, callers_temp_storage: &'a RowArena, order_by: &[ColumnOrder]) -> Datum<'a>
496where
497    I: IntoIterator<Item = Datum<'a>>,
498{
499    let temp_storage = RowArena::new();
500    let datums = rank_no_list(datums, &temp_storage, order_by);
501
502    callers_temp_storage.make_datum(|packer| {
503        packer.push_list(datums);
504    })
505}
506
507/// Like `rank`, but doesn't perform the final wrapping in a list, returning an Iterator
508/// instead.
509fn rank_no_list<'a: 'b, 'b, I>(
510    datums: I,
511    callers_temp_storage: &'b RowArena,
512    order_by: &[ColumnOrder],
513) -> impl Iterator<Item = Datum<'b>>
514where
515    I: IntoIterator<Item = Datum<'a>>,
516{
517    // Keep the row used for ordering around, as it is used to determine the rank
518    let datums = order_aggregate_datums_with_rank(datums, order_by);
519
520    let mut datums = datums
521        .into_iter()
522        .map(|(d0, order_row)| {
523            d0.unwrap_list()
524                .iter()
525                .map(move |d1| (d1, order_row.clone()))
526        })
527        .flatten();
528
529    callers_temp_storage.reserve(datums.size_hint().0);
530    datums
531        .next()
532        .map_or(vec![], |(first_datum, first_order_row)| {
533            // Folding with (last order_by row, last assigned rank,
534            // row number, output vec)
535            datums.fold(
536                (first_order_row, 1, 1, vec![(first_datum, 1)]),
537                |mut acc, (next_datum, next_order_row)| {
538                let (ref mut acc_row, ref mut acc_rank, ref mut acc_row_num, ref mut output) = acc;
539                *acc_row_num += 1;
540                // Identity is based on the order_by expression
541                if *acc_row != next_order_row {
542                    *acc_rank = *acc_row_num;
543                    *acc_row = next_order_row;
544                }
545
546                (*output).push((next_datum, *acc_rank));
547                acc
548            })
549        }.3).into_iter().map(|(d, i)| {
550        callers_temp_storage.make_datum(|packer| {
551            packer.push_list_with(|packer| {
552                packer.push(Datum::Int64(i));
553                packer.push(d);
554            });
555        })
556    })
557}
558
559/// The expected input is in the format of `[((OriginalRow, [EncodedArgs]), OrderByExprs...)]`
560/// The output is in the format of `[result_value, original_row]`.
561/// See an example at `lag_lead`, where the input-output formats are similar.
562fn dense_rank<'a, I>(
563    datums: I,
564    callers_temp_storage: &'a RowArena,
565    order_by: &[ColumnOrder],
566) -> Datum<'a>
567where
568    I: IntoIterator<Item = Datum<'a>>,
569{
570    let temp_storage = RowArena::new();
571    let datums = dense_rank_no_list(datums, &temp_storage, order_by);
572
573    callers_temp_storage.make_datum(|packer| {
574        packer.push_list(datums);
575    })
576}
577
578/// Like `dense_rank`, but doesn't perform the final wrapping in a list, returning an Iterator
579/// instead.
580fn dense_rank_no_list<'a: 'b, 'b, I>(
581    datums: I,
582    callers_temp_storage: &'b RowArena,
583    order_by: &[ColumnOrder],
584) -> impl Iterator<Item = Datum<'b>>
585where
586    I: IntoIterator<Item = Datum<'a>>,
587{
588    // Keep the row used for ordering around, as it is used to determine the rank
589    let datums = order_aggregate_datums_with_rank(datums, order_by);
590
591    let mut datums = datums
592        .into_iter()
593        .map(|(d0, order_row)| {
594            d0.unwrap_list()
595                .iter()
596                .map(move |d1| (d1, order_row.clone()))
597        })
598        .flatten();
599
600    callers_temp_storage.reserve(datums.size_hint().0);
601    datums
602        .next()
603        .map_or(vec![], |(first_datum, first_order_row)| {
604            // Folding with (last order_by row, last assigned rank,
605            // output vec)
606            datums.fold(
607                (first_order_row, 1, vec![(first_datum, 1)]),
608                |mut acc, (next_datum, next_order_row)| {
609                let (ref mut acc_row, ref mut acc_rank, ref mut output) = acc;
610                // Identity is based on the order_by expression
611                if *acc_row != next_order_row {
612                    *acc_rank += 1;
613                    *acc_row = next_order_row;
614                }
615
616                (*output).push((next_datum, *acc_rank));
617                acc
618            })
619        }.2).into_iter().map(|(d, i)| {
620        callers_temp_storage.make_datum(|packer| {
621            packer.push_list_with(|packer| {
622                packer.push(Datum::Int64(i));
623                packer.push(d);
624            });
625        })
626    })
627}
628
629/// The expected input is in the format of `[((OriginalRow, EncodedArgs), OrderByExprs...)]`
630/// For example,
631///
632/// lag(x*y, 1, null) over (partition by x+y order by x-y, x/y)
633///
634/// list of:
635/// row(
636///   row(
637///     row(#0, #1),
638///     row((#0 * #1), 1, null)
639///   ),
640///   (#0 - #1),
641///   (#0 / #1)
642/// )
643///
644/// The output is in the format of `[result_value, original_row]`, e.g.
645/// list of:
646/// row(
647///   42,
648///   row(7, 8)
649/// )
650fn lag_lead<'a, I>(
651    datums: I,
652    callers_temp_storage: &'a RowArena,
653    order_by: &[ColumnOrder],
654    lag_lead_type: &LagLeadType,
655    ignore_nulls: &bool,
656) -> Datum<'a>
657where
658    I: IntoIterator<Item = Datum<'a>>,
659{
660    let temp_storage = RowArena::new();
661    let iter = lag_lead_no_list(datums, &temp_storage, order_by, lag_lead_type, ignore_nulls);
662    callers_temp_storage.make_datum(|packer| {
663        packer.push_list(iter);
664    })
665}
666
667/// Like `lag_lead`, but doesn't perform the final wrapping in a list, returning an Iterator
668/// instead.
669fn lag_lead_no_list<'a: 'b, 'b, I>(
670    datums: I,
671    callers_temp_storage: &'b RowArena,
672    order_by: &[ColumnOrder],
673    lag_lead_type: &LagLeadType,
674    ignore_nulls: &bool,
675) -> impl Iterator<Item = Datum<'b>>
676where
677    I: IntoIterator<Item = Datum<'a>>,
678{
679    // Sort the datums according to the ORDER BY expressions and return the (OriginalRow, EncodedArgs) record
680    let datums = order_aggregate_datums(datums, order_by);
681
682    // Take the (OriginalRow, EncodedArgs) records and unwrap them into separate datums.
683    // EncodedArgs = (InputValue, Offset, DefaultValue) for Lag/Lead
684    // (`OriginalRow` is kept in a record form, as we don't need to look inside that.)
685    let (orig_rows, unwrapped_args): (Vec<_>, Vec<_>) = datums
686        .into_iter()
687        .map(|d| {
688            let mut iter = d.unwrap_list().iter();
689            let original_row = iter.next().unwrap();
690            let (input_value, offset, default_value) =
691                unwrap_lag_lead_encoded_args(iter.next().unwrap());
692            (original_row, (input_value, offset, default_value))
693        })
694        .unzip();
695
696    let result = lag_lead_inner(unwrapped_args, lag_lead_type, ignore_nulls);
697
698    callers_temp_storage.reserve(result.len());
699    result
700        .into_iter()
701        .zip_eq(orig_rows)
702        .map(|(result_value, original_row)| {
703            callers_temp_storage.make_datum(|packer| {
704                packer.push_list_with(|packer| {
705                    packer.push(result_value);
706                    packer.push(original_row);
707                });
708            })
709        })
710}
711
712/// lag/lead's arguments are in a record. This function unwraps this record.
713fn unwrap_lag_lead_encoded_args(encoded_args: Datum) -> (Datum, Datum, Datum) {
714    let mut encoded_args_iter = encoded_args.unwrap_list().iter();
715    let (input_value, offset, default_value) = (
716        encoded_args_iter.next().unwrap(),
717        encoded_args_iter.next().unwrap(),
718        encoded_args_iter.next().unwrap(),
719    );
720    (input_value, offset, default_value)
721}
722
723/// Each element of `args` has the 3 arguments evaluated for a single input row.
724/// Returns the results for each input row.
725fn lag_lead_inner<'a>(
726    args: Vec<(Datum<'a>, Datum<'a>, Datum<'a>)>,
727    lag_lead_type: &LagLeadType,
728    ignore_nulls: &bool,
729) -> Vec<Datum<'a>> {
730    if *ignore_nulls {
731        lag_lead_inner_ignore_nulls(args, lag_lead_type)
732    } else {
733        lag_lead_inner_respect_nulls(args, lag_lead_type)
734    }
735}
736
737fn lag_lead_inner_respect_nulls<'a>(
738    args: Vec<(Datum<'a>, Datum<'a>, Datum<'a>)>,
739    lag_lead_type: &LagLeadType,
740) -> Vec<Datum<'a>> {
741    let mut result: Vec<Datum> = Vec::with_capacity(args.len());
742    for (idx, (_, offset, default_value)) in args.iter().enumerate() {
743        // Null offsets are acceptable, and always return null
744        if offset.is_null() {
745            result.push(Datum::Null);
746            continue;
747        }
748
749        let idx = i64::try_from(idx).expect("Array index does not fit in i64");
750        let offset = i64::from(offset.unwrap_int32());
751        let offset = match lag_lead_type {
752            LagLeadType::Lag => -offset,
753            LagLeadType::Lead => offset,
754        };
755
756        // Get a Datum from `datums`. Return None if index is out of range.
757        let datums_get = |i: i64| -> Option<Datum> {
758            match u64::try_from(i) {
759                Ok(i) => args
760                    .get(usize::cast_from(i))
761                    .map(|d| Some(d.0)) // succeeded in getting a Datum from the vec
762                    .unwrap_or(None), // overindexing
763                Err(_) => None, // underindexing (negative index)
764            }
765        };
766
767        let lagged_value = datums_get(idx + offset).unwrap_or(*default_value);
768
769        result.push(lagged_value);
770    }
771
772    result
773}
774
775// `i64` indexes get involved in this function because it's convenient to allow negative indexes and
776// have `datums_get` fail on them, and thus handle the beginning and end of the input vector
777// uniformly, rather than checking underflow separately during index manipulations.
778#[allow(clippy::as_conversions)]
779fn lag_lead_inner_ignore_nulls<'a>(
780    args: Vec<(Datum<'a>, Datum<'a>, Datum<'a>)>,
781    lag_lead_type: &LagLeadType,
782) -> Vec<Datum<'a>> {
783    // We check here once that even the largest index fits in `i64`, and then do silent `as`
784    // conversions from `usize` indexes to `i64` indexes throughout this function.
785    if i64::try_from(args.len()).is_err() {
786        panic!("window partition way too big")
787    }
788    // Preparation: Make sure we can jump over a run of nulls in constant time, i.e., regardless of
789    // how many nulls the run has. The following skip tables will point to the next non-null index.
790    let mut skip_nulls_backward = vec![None; args.len()];
791    let mut last_non_null: i64 = -1;
792    let pairs = args
793        .iter()
794        .enumerate()
795        .zip_eq(skip_nulls_backward.iter_mut());
796    for ((i, (d, _, _)), slot) in pairs {
797        if d.is_null() {
798            *slot = Some(last_non_null);
799        } else {
800            last_non_null = i as i64;
801        }
802    }
803    let mut skip_nulls_forward = vec![None; args.len()];
804    let mut last_non_null: i64 = args.len() as i64;
805    let pairs = args
806        .iter()
807        .enumerate()
808        .rev()
809        .zip_eq(skip_nulls_forward.iter_mut().rev());
810    for ((i, (d, _, _)), slot) in pairs {
811        if d.is_null() {
812            *slot = Some(last_non_null);
813        } else {
814            last_non_null = i as i64;
815        }
816    }
817
818    // The actual computation.
819    let mut result: Vec<Datum> = Vec::with_capacity(args.len());
820    for (idx, (_, offset, default_value)) in args.iter().enumerate() {
821        // Null offsets are acceptable, and always return null
822        if offset.is_null() {
823            result.push(Datum::Null);
824            continue;
825        }
826
827        let idx = idx as i64; // checked at the beginning of the function that len() fits
828        let offset = i64::cast_from(offset.unwrap_int32());
829        let offset = match lag_lead_type {
830            LagLeadType::Lag => -offset,
831            LagLeadType::Lead => offset,
832        };
833        let increment = offset.signum();
834
835        // Get a Datum from `datums`. Return None if index is out of range.
836        let datums_get = |i: i64| -> Option<Datum> {
837            match u64::try_from(i) {
838                Ok(i) => args
839                    .get(usize::cast_from(i))
840                    .map(|d| Some(d.0)) // succeeded in getting a Datum from the vec
841                    .unwrap_or(None), // overindexing
842                Err(_) => None, // underindexing (negative index)
843            }
844        };
845
846        let lagged_value = if increment != 0 {
847            // We start j from idx, and step j until we have seen an abs(offset) number of non-null
848            // values or reach the beginning or end of the partition.
849            //
850            // If offset is big, then this is slow: Considering the entire function, it's
851            // `O(partition_size * offset)`.
852            // However, a common use case is an offset of 1, for which this doesn't matter.
853            // TODO: For larger offsets, we could have a completely different implementation
854            // that starts the inner loop from the index where we found the previous result:
855            // https://github.com/MaterializeInc/materialize/pull/29287#discussion_r1738695174
856            let mut j = idx;
857            for _ in 0..num::abs(offset) {
858                j += increment;
859                // Jump over a run of nulls
860                if datums_get(j).is_some_and(|d| d.is_null()) {
861                    let ju = j as usize; // `j >= 0` because of the above `is_some_and`
862                    if increment > 0 {
863                        j = skip_nulls_forward[ju].expect("checked above that it's null");
864                    } else {
865                        j = skip_nulls_backward[ju].expect("checked above that it's null");
866                    }
867                }
868                if datums_get(j).is_none() {
869                    break;
870                }
871            }
872            match datums_get(j) {
873                Some(datum) => datum,
874                None => *default_value,
875            }
876        } else {
877            assert_eq!(offset, 0);
878            let datum = datums_get(idx).expect("known to exist");
879            if !datum.is_null() {
880                datum
881            } else {
882                // Not clear what should the semantics be here. See
883                // https://github.com/MaterializeInc/database-issues/issues/8497
884                // (We used to run into an infinite loop in this case, so panicking is
885                // better.)
886                panic!("0 offset in lag/lead IGNORE NULLS");
887            }
888        };
889
890        result.push(lagged_value);
891    }
892
893    result
894}
895
896/// The expected input is in the format of [((OriginalRow, InputValue), OrderByExprs...)]
897fn first_value<'a, I>(
898    datums: I,
899    callers_temp_storage: &'a RowArena,
900    order_by: &[ColumnOrder],
901    window_frame: &WindowFrame,
902) -> Datum<'a>
903where
904    I: IntoIterator<Item = Datum<'a>>,
905{
906    let temp_storage = RowArena::new();
907    let iter = first_value_no_list(datums, &temp_storage, order_by, window_frame);
908    callers_temp_storage.make_datum(|packer| {
909        packer.push_list(iter);
910    })
911}
912
913/// Like `first_value`, but doesn't perform the final wrapping in a list, returning an Iterator
914/// instead.
915fn first_value_no_list<'a: 'b, 'b, I>(
916    datums: I,
917    callers_temp_storage: &'b RowArena,
918    order_by: &[ColumnOrder],
919    window_frame: &WindowFrame,
920) -> impl Iterator<Item = Datum<'b>>
921where
922    I: IntoIterator<Item = Datum<'a>>,
923{
924    // Sort the datums according to the ORDER BY expressions and return the (OriginalRow, InputValue) record
925    let datums = order_aggregate_datums(datums, order_by);
926
927    // Decode the input (OriginalRow, InputValue) into separate datums
928    let (orig_rows, args): (Vec<_>, Vec<_>) = datums
929        .into_iter()
930        .map(|d| {
931            let mut iter = d.unwrap_list().iter();
932            let original_row = iter.next().unwrap();
933            let arg = iter.next().unwrap();
934
935            (original_row, arg)
936        })
937        .unzip();
938
939    let results = first_value_inner(args, window_frame);
940
941    callers_temp_storage.reserve(results.len());
942    results
943        .into_iter()
944        .zip_eq(orig_rows)
945        .map(|(result_value, original_row)| {
946            callers_temp_storage.make_datum(|packer| {
947                packer.push_list_with(|packer| {
948                    packer.push(result_value);
949                    packer.push(original_row);
950                });
951            })
952        })
953}
954
955fn first_value_inner<'a>(datums: Vec<Datum<'a>>, window_frame: &WindowFrame) -> Vec<Datum<'a>> {
956    let length = datums.len();
957    let mut result: Vec<Datum> = Vec::with_capacity(length);
958    for (idx, current_datum) in datums.iter().enumerate() {
959        let first_value = match &window_frame.start_bound {
960            // Always return the current value
961            WindowFrameBound::CurrentRow => *current_datum,
962            WindowFrameBound::UnboundedPreceding => {
963                if let WindowFrameBound::OffsetPreceding(end_offset) = &window_frame.end_bound {
964                    let end_offset = usize::cast_from(*end_offset);
965
966                    // If the frame ends before the first row, return null
967                    if idx < end_offset {
968                        Datum::Null
969                    } else {
970                        datums[0]
971                    }
972                } else {
973                    datums[0]
974                }
975            }
976            WindowFrameBound::OffsetPreceding(offset) => {
977                let start_offset = usize::cast_from(*offset);
978                let start_idx = idx.saturating_sub(start_offset);
979                if let WindowFrameBound::OffsetPreceding(end_offset) = &window_frame.end_bound {
980                    let end_offset = usize::cast_from(*end_offset);
981
982                    // If the frame is empty or ends before the first row, return null
983                    if start_offset < end_offset || idx < end_offset {
984                        Datum::Null
985                    } else {
986                        datums[start_idx]
987                    }
988                } else {
989                    datums[start_idx]
990                }
991            }
992            WindowFrameBound::OffsetFollowing(offset) => {
993                let start_offset = usize::cast_from(*offset);
994                let start_idx = idx.saturating_add(start_offset);
995                if let WindowFrameBound::OffsetFollowing(end_offset) = &window_frame.end_bound {
996                    // If the frame is empty or starts after the last row, return null
997                    if offset > end_offset || start_idx >= length {
998                        Datum::Null
999                    } else {
1000                        datums[start_idx]
1001                    }
1002                } else {
1003                    datums
1004                        .get(start_idx)
1005                        .map(|d| d.clone())
1006                        .unwrap_or(Datum::Null)
1007                }
1008            }
1009            // Forbidden during planning
1010            WindowFrameBound::UnboundedFollowing => unreachable!(),
1011        };
1012        result.push(first_value);
1013    }
1014    result
1015}
1016
1017/// The expected input is in the format of [((OriginalRow, InputValue), OrderByExprs...)]
1018fn last_value<'a, I>(
1019    datums: I,
1020    callers_temp_storage: &'a RowArena,
1021    order_by: &[ColumnOrder],
1022    window_frame: &WindowFrame,
1023) -> Datum<'a>
1024where
1025    I: IntoIterator<Item = Datum<'a>>,
1026{
1027    let temp_storage = RowArena::new();
1028    let iter = last_value_no_list(datums, &temp_storage, order_by, window_frame);
1029    callers_temp_storage.make_datum(|packer| {
1030        packer.push_list(iter);
1031    })
1032}
1033
1034/// Like `last_value`, but doesn't perform the final wrapping in a list, returning an Iterator
1035/// instead.
1036fn last_value_no_list<'a: 'b, 'b, I>(
1037    datums: I,
1038    callers_temp_storage: &'b RowArena,
1039    order_by: &[ColumnOrder],
1040    window_frame: &WindowFrame,
1041) -> impl Iterator<Item = Datum<'b>>
1042where
1043    I: IntoIterator<Item = Datum<'a>>,
1044{
1045    // Sort the datums according to the ORDER BY expressions and return the ((OriginalRow, InputValue), OrderByRow) record
1046    // The OrderByRow is kept around because it is required to compute the peer groups in RANGE mode
1047    let datums = order_aggregate_datums_with_rank(datums, order_by);
1048
1049    // Decode the input (OriginalRow, InputValue) into separate datums, while keeping the OrderByRow
1050    let size_hint = datums.size_hint().0;
1051    let mut args = Vec::with_capacity(size_hint);
1052    let mut original_rows = Vec::with_capacity(size_hint);
1053    let mut order_by_rows = Vec::with_capacity(size_hint);
1054    for (d, order_by_row) in datums.into_iter() {
1055        let mut iter = d.unwrap_list().iter();
1056        let original_row = iter.next().unwrap();
1057        let arg = iter.next().unwrap();
1058        order_by_rows.push(order_by_row);
1059        original_rows.push(original_row);
1060        args.push(arg);
1061    }
1062
1063    let results = last_value_inner(args, &order_by_rows, window_frame);
1064
1065    callers_temp_storage.reserve(results.len());
1066    results
1067        .into_iter()
1068        .zip_eq(original_rows)
1069        .map(|(result_value, original_row)| {
1070            callers_temp_storage.make_datum(|packer| {
1071                packer.push_list_with(|packer| {
1072                    packer.push(result_value);
1073                    packer.push(original_row);
1074                });
1075            })
1076        })
1077}
1078
1079fn last_value_inner<'a>(
1080    args: Vec<Datum<'a>>,
1081    order_by_rows: &Vec<Row>,
1082    window_frame: &WindowFrame,
1083) -> Vec<Datum<'a>> {
1084    let length = args.len();
1085    let mut results: Vec<Datum> = Vec::with_capacity(length);
1086    for (idx, (current_datum, order_by_row)) in args.iter().zip_eq(order_by_rows).enumerate() {
1087        let last_value = match &window_frame.end_bound {
1088            WindowFrameBound::CurrentRow => match &window_frame.units {
1089                // Always return the current value when in ROWS mode
1090                WindowFrameUnits::Rows => *current_datum,
1091                WindowFrameUnits::Range => {
1092                    // When in RANGE mode, return the last value of the peer group
1093                    // The peer group is the group of rows with the same ORDER BY value
1094                    // Note: Range is only supported for the default window frame (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW),
1095                    // which is why it does not appear in the other branches
1096                    let target_idx = order_by_rows[idx..]
1097                        .iter()
1098                        .enumerate()
1099                        .take_while(|(_, row)| *row == order_by_row)
1100                        .last()
1101                        .unwrap()
1102                        .0
1103                        + idx;
1104                    args[target_idx]
1105                }
1106                // GROUPS is not supported, and forbidden during planning
1107                WindowFrameUnits::Groups => unreachable!(),
1108            },
1109            WindowFrameBound::UnboundedFollowing => {
1110                if let WindowFrameBound::OffsetFollowing(start_offset) = &window_frame.start_bound {
1111                    let start_offset = usize::cast_from(*start_offset);
1112
1113                    // If the frame starts after the last row of the window, return null
1114                    if idx + start_offset > length - 1 {
1115                        Datum::Null
1116                    } else {
1117                        args[length - 1]
1118                    }
1119                } else {
1120                    args[length - 1]
1121                }
1122            }
1123            WindowFrameBound::OffsetFollowing(offset) => {
1124                let end_offset = usize::cast_from(*offset);
1125                let end_idx = idx.saturating_add(end_offset);
1126                if let WindowFrameBound::OffsetFollowing(start_offset) = &window_frame.start_bound {
1127                    let start_offset = usize::cast_from(*start_offset);
1128                    let start_idx = idx.saturating_add(start_offset);
1129
1130                    // If the frame is empty or starts after the last row of the window, return null
1131                    if end_offset < start_offset || start_idx >= length {
1132                        Datum::Null
1133                    } else {
1134                        // Return the last valid element in the window
1135                        args.get(end_idx).unwrap_or(&args[length - 1]).clone()
1136                    }
1137                } else {
1138                    args.get(end_idx).unwrap_or(&args[length - 1]).clone()
1139                }
1140            }
1141            WindowFrameBound::OffsetPreceding(offset) => {
1142                let end_offset = usize::cast_from(*offset);
1143                let end_idx = idx.saturating_sub(end_offset);
1144                if idx < end_offset {
1145                    // If the frame ends before the first row, return null
1146                    Datum::Null
1147                } else if let WindowFrameBound::OffsetPreceding(start_offset) =
1148                    &window_frame.start_bound
1149                {
1150                    // If the frame is empty, return null
1151                    if offset > start_offset {
1152                        Datum::Null
1153                    } else {
1154                        args[end_idx]
1155                    }
1156                } else {
1157                    args[end_idx]
1158                }
1159            }
1160            // Forbidden during planning
1161            WindowFrameBound::UnboundedPreceding => unreachable!(),
1162        };
1163        results.push(last_value);
1164    }
1165    results
1166}
1167
1168/// Executes `FusedValueWindowFunc` on a reduction group.
1169/// The expected input is in the format of `[((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)]`
1170/// where `Args1`, `Args2`, are the arguments of each of the fused functions. For functions that
1171/// have only a single argument (first_value/last_value), these are simple values. For functions
1172/// that have multiple arguments (lag/lead), these are also records.
1173fn fused_value_window_func<'a, I>(
1174    input_datums: I,
1175    callers_temp_storage: &'a RowArena,
1176    funcs: &Vec<AggregateFunc>,
1177    order_by: &Vec<ColumnOrder>,
1178) -> Datum<'a>
1179where
1180    I: IntoIterator<Item = Datum<'a>>,
1181{
1182    let temp_storage = RowArena::new();
1183    let iter = fused_value_window_func_no_list(input_datums, &temp_storage, funcs, order_by);
1184    callers_temp_storage.make_datum(|packer| {
1185        packer.push_list(iter);
1186    })
1187}
1188
1189/// Like `fused_value_window_func`, but doesn't perform the final wrapping in a list, returning an
1190/// Iterator instead.
1191fn fused_value_window_func_no_list<'a: 'b, 'b, I>(
1192    input_datums: I,
1193    callers_temp_storage: &'b RowArena,
1194    funcs: &Vec<AggregateFunc>,
1195    order_by: &Vec<ColumnOrder>,
1196) -> impl Iterator<Item = Datum<'b>>
1197where
1198    I: IntoIterator<Item = Datum<'a>>,
1199{
1200    let has_last_value = funcs
1201        .iter()
1202        .any(|f| matches!(f, AggregateFunc::LastValue { .. }));
1203
1204    let input_datums_with_ranks = order_aggregate_datums_with_rank(input_datums, order_by);
1205
1206    let size_hint = input_datums_with_ranks.size_hint().0;
1207    let mut encoded_argsss = vec![Vec::with_capacity(size_hint); funcs.len()];
1208    let mut original_rows = Vec::with_capacity(size_hint);
1209    let mut order_by_rows = Vec::with_capacity(size_hint);
1210    for (d, order_by_row) in input_datums_with_ranks {
1211        let mut iter = d.unwrap_list().iter();
1212        let original_row = iter.next().unwrap();
1213        original_rows.push(original_row);
1214        let mut argss_iter = iter.next().unwrap().unwrap_list().iter();
1215        for i in 0..funcs.len() {
1216            let encoded_args = argss_iter.next().unwrap();
1217            encoded_argsss[i].push(encoded_args);
1218        }
1219        if has_last_value {
1220            order_by_rows.push(order_by_row);
1221        }
1222    }
1223
1224    let mut results_per_row = vec![Vec::with_capacity(funcs.len()); original_rows.len()];
1225    for (func, encoded_argss) in funcs.iter().zip_eq(encoded_argsss) {
1226        let results = match func {
1227            AggregateFunc::LagLead {
1228                order_by: inner_order_by,
1229                lag_lead,
1230                ignore_nulls,
1231            } => {
1232                assert_eq!(order_by, inner_order_by);
1233                let unwrapped_argss = encoded_argss
1234                    .into_iter()
1235                    .map(|encoded_args| unwrap_lag_lead_encoded_args(encoded_args))
1236                    .collect();
1237                lag_lead_inner(unwrapped_argss, lag_lead, ignore_nulls)
1238            }
1239            AggregateFunc::FirstValue {
1240                order_by: inner_order_by,
1241                window_frame,
1242            } => {
1243                assert_eq!(order_by, inner_order_by);
1244                // (No unwrapping to do on the args here, because there is only 1 arg, so it's not
1245                // wrapped into a record.)
1246                first_value_inner(encoded_argss, window_frame)
1247            }
1248            AggregateFunc::LastValue {
1249                order_by: inner_order_by,
1250                window_frame,
1251            } => {
1252                assert_eq!(order_by, inner_order_by);
1253                // (No unwrapping to do on the args here, because there is only 1 arg, so it's not
1254                // wrapped into a record.)
1255                last_value_inner(encoded_argss, &order_by_rows, window_frame)
1256            }
1257            _ => panic!("unknown window function in FusedValueWindowFunc"),
1258        };
1259        for (results, result) in results_per_row.iter_mut().zip_eq(results) {
1260            results.push(result);
1261        }
1262    }
1263
1264    callers_temp_storage.reserve(2 * original_rows.len());
1265    results_per_row
1266        .into_iter()
1267        .enumerate()
1268        .map(move |(i, results)| {
1269            callers_temp_storage.make_datum(|packer| {
1270                packer.push_list_with(|packer| {
1271                    packer
1272                        .push(callers_temp_storage.make_datum(|packer| packer.push_list(results)));
1273                    packer.push(original_rows[i]);
1274                });
1275            })
1276        })
1277}
1278
1279/// `input_datums` is an entire window partition.
1280/// The expected input is in the format of `[((OriginalRow, InputValue), OrderByExprs...)]`
1281/// See also in the comment in `window_func_applied_to`.
1282///
1283/// `wrapped_aggregate`: e.g., for `sum(...) OVER (...)`, this is the `sum(...)`.
1284///
1285/// Note that this `order_by` doesn't have expressions, only `ColumnOrder`s. For an explanation,
1286/// see the comment on `WindowExprType`.
1287fn window_aggr<'a, I, A>(
1288    input_datums: I,
1289    callers_temp_storage: &'a RowArena,
1290    wrapped_aggregate: &AggregateFunc,
1291    order_by: &[ColumnOrder],
1292    window_frame: &WindowFrame,
1293) -> Datum<'a>
1294where
1295    I: IntoIterator<Item = Datum<'a>>,
1296    A: OneByOneAggr,
1297{
1298    let temp_storage = RowArena::new();
1299    let iter = window_aggr_no_list::<I, A>(
1300        input_datums,
1301        &temp_storage,
1302        wrapped_aggregate,
1303        order_by,
1304        window_frame,
1305    );
1306    callers_temp_storage.make_datum(|packer| {
1307        packer.push_list(iter);
1308    })
1309}
1310
1311/// Like `window_aggr`, but doesn't perform the final wrapping in a list, returning an Iterator
1312/// instead.
1313fn window_aggr_no_list<'a: 'b, 'b, I, A>(
1314    input_datums: I,
1315    callers_temp_storage: &'b RowArena,
1316    wrapped_aggregate: &AggregateFunc,
1317    order_by: &[ColumnOrder],
1318    window_frame: &WindowFrame,
1319) -> impl Iterator<Item = Datum<'b>>
1320where
1321    I: IntoIterator<Item = Datum<'a>>,
1322    A: OneByOneAggr,
1323{
1324    // Sort the datums according to the ORDER BY expressions and return the ((OriginalRow, InputValue), OrderByRow) record
1325    // The OrderByRow is kept around because it is required to compute the peer groups in RANGE mode
1326    let datums = order_aggregate_datums_with_rank(input_datums, order_by);
1327
1328    // Decode the input (OriginalRow, InputValue) into separate datums, while keeping the OrderByRow
1329    let size_hint = datums.size_hint().0;
1330    let mut args: Vec<Datum> = Vec::with_capacity(size_hint);
1331    let mut original_rows: Vec<Datum> = Vec::with_capacity(size_hint);
1332    let mut order_by_rows = Vec::with_capacity(size_hint);
1333    for (d, order_by_row) in datums.into_iter() {
1334        let mut iter = d.unwrap_list().iter();
1335        let original_row = iter.next().unwrap();
1336        let arg = iter.next().unwrap();
1337        order_by_rows.push(order_by_row);
1338        original_rows.push(original_row);
1339        args.push(arg);
1340    }
1341
1342    let results = window_aggr_inner::<A>(
1343        args,
1344        &order_by_rows,
1345        wrapped_aggregate,
1346        order_by,
1347        window_frame,
1348        callers_temp_storage,
1349    );
1350
1351    callers_temp_storage.reserve(results.len());
1352    results
1353        .into_iter()
1354        .zip_eq(original_rows)
1355        .map(|(result_value, original_row)| {
1356            callers_temp_storage.make_datum(|packer| {
1357                packer.push_list_with(|packer| {
1358                    packer.push(result_value);
1359                    packer.push(original_row);
1360                });
1361            })
1362        })
1363}
1364
1365fn window_aggr_inner<'a, A>(
1366    mut args: Vec<Datum<'a>>,
1367    order_by_rows: &Vec<Row>,
1368    wrapped_aggregate: &AggregateFunc,
1369    order_by: &[ColumnOrder],
1370    window_frame: &WindowFrame,
1371    temp_storage: &'a RowArena,
1372) -> Vec<Datum<'a>>
1373where
1374    A: OneByOneAggr,
1375{
1376    let length = args.len();
1377    let mut result: Vec<Datum> = Vec::with_capacity(length);
1378
1379    // In this degenerate case, all results would be `wrapped_aggregate.default()` (usually null).
1380    // However, this currently can't happen, because
1381    // - Groups frame mode is currently not supported;
1382    // - Range frame mode is currently supported only for the default frame, which includes the
1383    //   current row.
1384    soft_assert_or_log!(
1385        !((matches!(window_frame.units, WindowFrameUnits::Groups)
1386            || matches!(window_frame.units, WindowFrameUnits::Range))
1387            && !window_frame.includes_current_row()),
1388        "window frame without current row"
1389    );
1390
1391    if (matches!(
1392        window_frame.start_bound,
1393        WindowFrameBound::UnboundedPreceding
1394    ) && matches!(window_frame.end_bound, WindowFrameBound::UnboundedFollowing))
1395        || (order_by.is_empty()
1396            && (matches!(window_frame.units, WindowFrameUnits::Groups)
1397                || matches!(window_frame.units, WindowFrameUnits::Range))
1398            && window_frame.includes_current_row())
1399    {
1400        // Either
1401        //  - UNBOUNDED frame in both directions, or
1402        //  - There is no ORDER BY and the frame is such that the current peer group is included.
1403        //    (The current peer group will be the whole partition if there is no ORDER BY.)
1404        // We simply need to compute the aggregate once, on the entire partition, and each input
1405        // row will get this one aggregate value as result.
1406        let result_value =
1407            wrapped_aggregate.eval(args.into_iter().map(|d| (d, Diff::ONE)), temp_storage);
1408        // Every row will get the above aggregate as result.
1409        for _ in 0..length {
1410            result.push(result_value);
1411        }
1412    } else {
1413        fn rows_between_unbounded_preceding_and_current_row<'a, A>(
1414            args: Vec<Datum<'a>>,
1415            result: &mut Vec<Datum<'a>>,
1416            mut one_by_one_aggr: A,
1417            temp_storage: &'a RowArena,
1418        ) where
1419            A: OneByOneAggr,
1420        {
1421            for current_arg in args.into_iter() {
1422                one_by_one_aggr.give(&current_arg);
1423                let result_value = one_by_one_aggr.get_current_aggregate(temp_storage);
1424                result.push(result_value);
1425            }
1426        }
1427
1428        fn groups_between_unbounded_preceding_and_current_row<'a, A>(
1429            args: Vec<Datum<'a>>,
1430            order_by_rows: &Vec<Row>,
1431            result: &mut Vec<Datum<'a>>,
1432            mut one_by_one_aggr: A,
1433            temp_storage: &'a RowArena,
1434        ) where
1435            A: OneByOneAggr,
1436        {
1437            let mut peer_group_start = 0;
1438            while peer_group_start < args.len() {
1439                // Find the boundaries of the current peer group.
1440                // peer_group_start will point to the first element of the peer group,
1441                // peer_group_end will point to _just after_ the last element of the peer group.
1442                let mut peer_group_end = peer_group_start + 1;
1443                while peer_group_end < args.len()
1444                    && order_by_rows[peer_group_start] == order_by_rows[peer_group_end]
1445                {
1446                    // The peer group goes on while the OrderByRows not differ.
1447                    peer_group_end += 1;
1448                }
1449                // Let's compute the aggregate (which will be the same for all records in this
1450                // peer group).
1451                for current_arg in args[peer_group_start..peer_group_end].iter() {
1452                    one_by_one_aggr.give(current_arg);
1453                }
1454                let agg_for_peer_group = one_by_one_aggr.get_current_aggregate(temp_storage);
1455                // Put the above aggregate into each record in the peer group.
1456                for _ in args[peer_group_start..peer_group_end].iter() {
1457                    result.push(agg_for_peer_group);
1458                }
1459                // Point to the start of the next peer group.
1460                peer_group_start = peer_group_end;
1461            }
1462        }
1463
1464        fn rows_between_offset_and_offset<'a>(
1465            args: Vec<Datum<'a>>,
1466            result: &mut Vec<Datum<'a>>,
1467            wrapped_aggregate: &AggregateFunc,
1468            temp_storage: &'a RowArena,
1469            offset_start: i64,
1470            offset_end: i64,
1471        ) {
1472            let len = args
1473                .len()
1474                .to_i64()
1475                .expect("window partition's len should fit into i64");
1476            for i in 0..len {
1477                let i = i.to_i64().expect("window partition shouldn't be super big");
1478                // Trim the start of the frame to make it not reach over the start of the window
1479                // partition.
1480                let frame_start = max(i + offset_start, 0)
1481                    .to_usize()
1482                    .expect("The max made sure it's not negative");
1483                // Trim the end of the frame to make it not reach over the end of the window
1484                // partition.
1485                let frame_end = min(i + offset_end, len - 1).to_usize();
1486                match frame_end {
1487                    Some(frame_end) => {
1488                        if frame_start <= frame_end {
1489                            // Compute the aggregate on the frame.
1490                            // TODO:
1491                            // This implementation is quite slow if the frame is large: we do an
1492                            // inner loop over the entire frame, and compute the aggregate from
1493                            // scratch. We could do better:
1494                            //  - For invertible aggregations we could do a rolling aggregation.
1495                            //  - There are various tricks for min/max as well, making use of either
1496                            //    the fixed size of the window, or that we are not retracting
1497                            //    arbitrary elements but doing queue operations. E.g., see
1498                            //    http://codercareer.blogspot.com/2012/02/no-33-maximums-in-sliding-windows.html
1499                            let frame_values = args[frame_start..=frame_end]
1500                                .iter()
1501                                .map(|d| (*d, Diff::ONE));
1502                            let result_value = wrapped_aggregate.eval(frame_values, temp_storage);
1503                            result.push(result_value);
1504                        } else {
1505                            // frame_start > frame_end, so this is an empty frame.
1506                            let result_value = wrapped_aggregate.default();
1507                            result.push(result_value);
1508                        }
1509                    }
1510                    None => {
1511                        // frame_end would be negative, so this is an empty frame.
1512                        let result_value = wrapped_aggregate.default();
1513                        result.push(result_value);
1514                    }
1515                }
1516            }
1517        }
1518
1519        match (
1520            &window_frame.units,
1521            &window_frame.start_bound,
1522            &window_frame.end_bound,
1523        ) {
1524            // Cases where one edge of the frame is CurrentRow.
1525            // Note that these cases could be merged into the more general cases below where one
1526            // edge is some offset (with offset = 0), but the CurrentRow cases probably cover 95%
1527            // of user queries, so let's make this simple and fast.
1528            (Rows, UnboundedPreceding, CurrentRow) => {
1529                rows_between_unbounded_preceding_and_current_row::<A>(
1530                    args,
1531                    &mut result,
1532                    A::new(wrapped_aggregate, false),
1533                    temp_storage,
1534                );
1535            }
1536            (Rows, CurrentRow, UnboundedFollowing) => {
1537                // Same as above, but reverse.
1538                args.reverse();
1539                rows_between_unbounded_preceding_and_current_row::<A>(
1540                    args,
1541                    &mut result,
1542                    A::new(wrapped_aggregate, true),
1543                    temp_storage,
1544                );
1545                result.reverse();
1546            }
1547            (Range, UnboundedPreceding, CurrentRow) => {
1548                // Note that for the default frame, the RANGE frame mode is identical to the GROUPS
1549                // frame mode.
1550                groups_between_unbounded_preceding_and_current_row::<A>(
1551                    args,
1552                    order_by_rows,
1553                    &mut result,
1554                    A::new(wrapped_aggregate, false),
1555                    temp_storage,
1556                );
1557            }
1558            // The next several cases all call `rows_between_offset_and_offset`. Note that the
1559            // offset passed to `rows_between_offset_and_offset` should be negated when it's
1560            // PRECEDING.
1561            (Rows, OffsetPreceding(start_prec), OffsetPreceding(end_prec)) => {
1562                let start_prec = start_prec.to_i64().expect(
1563                    "window frame start OFFSET shouldn't be super big (the planning ensured this)",
1564                );
1565                let end_prec = end_prec.to_i64().expect(
1566                    "window frame end OFFSET shouldn't be super big (the planning ensured this)",
1567                );
1568                rows_between_offset_and_offset(
1569                    args,
1570                    &mut result,
1571                    wrapped_aggregate,
1572                    temp_storage,
1573                    -start_prec,
1574                    -end_prec,
1575                );
1576            }
1577            (Rows, OffsetPreceding(start_prec), OffsetFollowing(end_fol)) => {
1578                let start_prec = start_prec.to_i64().expect(
1579                    "window frame start OFFSET shouldn't be super big (the planning ensured this)",
1580                );
1581                let end_fol = end_fol.to_i64().expect(
1582                    "window frame end OFFSET shouldn't be super big (the planning ensured this)",
1583                );
1584                rows_between_offset_and_offset(
1585                    args,
1586                    &mut result,
1587                    wrapped_aggregate,
1588                    temp_storage,
1589                    -start_prec,
1590                    end_fol,
1591                );
1592            }
1593            (Rows, OffsetFollowing(start_fol), OffsetFollowing(end_fol)) => {
1594                let start_fol = start_fol.to_i64().expect(
1595                    "window frame start OFFSET shouldn't be super big (the planning ensured this)",
1596                );
1597                let end_fol = end_fol.to_i64().expect(
1598                    "window frame end OFFSET shouldn't be super big (the planning ensured this)",
1599                );
1600                rows_between_offset_and_offset(
1601                    args,
1602                    &mut result,
1603                    wrapped_aggregate,
1604                    temp_storage,
1605                    start_fol,
1606                    end_fol,
1607                );
1608            }
1609            (Rows, OffsetFollowing(_), OffsetPreceding(_)) => {
1610                unreachable!() // The planning ensured that this nonsensical case can't happen
1611            }
1612            (Rows, OffsetPreceding(start_prec), CurrentRow) => {
1613                let start_prec = start_prec.to_i64().expect(
1614                    "window frame start OFFSET shouldn't be super big (the planning ensured this)",
1615                );
1616                let end_fol = 0;
1617                rows_between_offset_and_offset(
1618                    args,
1619                    &mut result,
1620                    wrapped_aggregate,
1621                    temp_storage,
1622                    -start_prec,
1623                    end_fol,
1624                );
1625            }
1626            (Rows, CurrentRow, OffsetFollowing(end_fol)) => {
1627                let start_fol = 0;
1628                let end_fol = end_fol.to_i64().expect(
1629                    "window frame end OFFSET shouldn't be super big (the planning ensured this)",
1630                );
1631                rows_between_offset_and_offset(
1632                    args,
1633                    &mut result,
1634                    wrapped_aggregate,
1635                    temp_storage,
1636                    start_fol,
1637                    end_fol,
1638                );
1639            }
1640            (Rows, CurrentRow, CurrentRow) => {
1641                // We could have a more efficient implementation for this, but this is probably
1642                // super rare. (Might be more common with RANGE or GROUPS frame mode, though!)
1643                let start_fol = 0;
1644                let end_fol = 0;
1645                rows_between_offset_and_offset(
1646                    args,
1647                    &mut result,
1648                    wrapped_aggregate,
1649                    temp_storage,
1650                    start_fol,
1651                    end_fol,
1652                );
1653            }
1654            (Rows, CurrentRow, OffsetPreceding(_))
1655            | (Rows, UnboundedFollowing, _)
1656            | (Rows, _, UnboundedPreceding)
1657            | (Rows, OffsetFollowing(..), CurrentRow) => {
1658                unreachable!() // The planning ensured that these nonsensical cases can't happen
1659            }
1660            (Rows, UnboundedPreceding, UnboundedFollowing) => {
1661                // This is handled by the complicated if condition near the beginning of this
1662                // function.
1663                unreachable!()
1664            }
1665            (Rows, UnboundedPreceding, OffsetPreceding(_))
1666            | (Rows, UnboundedPreceding, OffsetFollowing(_))
1667            | (Rows, OffsetPreceding(..), UnboundedFollowing)
1668            | (Rows, OffsetFollowing(..), UnboundedFollowing) => {
1669                // Unsupported. Bail in the planner.
1670                // https://github.com/MaterializeInc/database-issues/issues/6720
1671                unreachable!()
1672            }
1673            (Range, _, _) => {
1674                // Unsupported.
1675                // The planner doesn't allow Range frame mode for now (except for the default
1676                // frame), see https://github.com/MaterializeInc/database-issues/issues/6585
1677                // Note that it would be easy to handle (Range, CurrentRow, UnboundedFollowing):
1678                // it would be similar to (Rows, CurrentRow, UnboundedFollowing), but would call
1679                // groups_between_unbounded_preceding_and_current_row.
1680                unreachable!()
1681            }
1682            (Groups, _, _) => {
1683                // Unsupported.
1684                // The planner doesn't allow Groups frame mode for now, see
1685                // https://github.com/MaterializeInc/database-issues/issues/6588
1686                unreachable!()
1687            }
1688        }
1689    }
1690
1691    result
1692}
1693
1694/// Computes a bundle of fused window aggregations.
1695/// The input is similar to `window_aggr`, but `InputValue` is not just a single value, but a record
1696/// where each component is the input to one of the aggregations.
1697fn fused_window_aggr<'a, I, A>(
1698    input_datums: I,
1699    callers_temp_storage: &'a RowArena,
1700    wrapped_aggregates: &Vec<AggregateFunc>,
1701    order_by: &Vec<ColumnOrder>,
1702    window_frame: &WindowFrame,
1703) -> Datum<'a>
1704where
1705    I: IntoIterator<Item = Datum<'a>>,
1706    A: OneByOneAggr,
1707{
1708    let temp_storage = RowArena::new();
1709    let iter = fused_window_aggr_no_list::<_, A>(
1710        input_datums,
1711        &temp_storage,
1712        wrapped_aggregates,
1713        order_by,
1714        window_frame,
1715    );
1716    callers_temp_storage.make_datum(|packer| {
1717        packer.push_list(iter);
1718    })
1719}
1720
1721/// Like `fused_window_aggr`, but doesn't perform the final wrapping in a list, returning an
1722/// Iterator instead.
1723fn fused_window_aggr_no_list<'a: 'b, 'b, I, A>(
1724    input_datums: I,
1725    callers_temp_storage: &'b RowArena,
1726    wrapped_aggregates: &Vec<AggregateFunc>,
1727    order_by: &Vec<ColumnOrder>,
1728    window_frame: &WindowFrame,
1729) -> impl Iterator<Item = Datum<'b>>
1730where
1731    I: IntoIterator<Item = Datum<'a>>,
1732    A: OneByOneAggr,
1733{
1734    // Sort the datums according to the ORDER BY expressions and return the ((OriginalRow, InputValue), OrderByRow) record
1735    // The OrderByRow is kept around because it is required to compute the peer groups in RANGE mode
1736    let datums = order_aggregate_datums_with_rank(input_datums, order_by);
1737
1738    let size_hint = datums.size_hint().0;
1739    let mut argss = vec![Vec::with_capacity(size_hint); wrapped_aggregates.len()];
1740    let mut original_rows = Vec::with_capacity(size_hint);
1741    let mut order_by_rows = Vec::with_capacity(size_hint);
1742    for (d, order_by_row) in datums {
1743        let mut iter = d.unwrap_list().iter();
1744        let original_row = iter.next().unwrap();
1745        original_rows.push(original_row);
1746        let args_iter = iter.next().unwrap().unwrap_list().iter();
1747        // Push each argument into the respective list
1748        for (args, arg) in argss.iter_mut().zip_eq(args_iter) {
1749            args.push(arg);
1750        }
1751        order_by_rows.push(order_by_row);
1752    }
1753
1754    let mut results_per_row =
1755        vec![Vec::with_capacity(wrapped_aggregates.len()); original_rows.len()];
1756    for (wrapped_aggr, args) in wrapped_aggregates.iter().zip_eq(argss) {
1757        let results = window_aggr_inner::<A>(
1758            args,
1759            &order_by_rows,
1760            wrapped_aggr,
1761            order_by,
1762            window_frame,
1763            callers_temp_storage,
1764        );
1765        for (results, result) in results_per_row.iter_mut().zip_eq(results) {
1766            results.push(result);
1767        }
1768    }
1769
1770    callers_temp_storage.reserve(2 * original_rows.len());
1771    results_per_row
1772        .into_iter()
1773        .enumerate()
1774        .map(move |(i, results)| {
1775            callers_temp_storage.make_datum(|packer| {
1776                packer.push_list_with(|packer| {
1777                    packer
1778                        .push(callers_temp_storage.make_datum(|packer| packer.push_list(results)));
1779                    packer.push(original_rows[i]);
1780                });
1781            })
1782        })
1783}
1784
1785/// An implementation of an aggregation where we can send in the input elements one-by-one, and
1786/// can also ask the current aggregate at any moment. (This just delegates to other aggregation
1787/// evaluation approaches.)
1788pub trait OneByOneAggr {
1789    /// The `reverse` parameter makes the aggregations process input elements in reverse order.
1790    /// This has an effect only for non-commutative aggregations, e.g. `list_agg`. These are
1791    /// currently only some of the Basic aggregations. (Basic aggregations are handled by
1792    /// `NaiveOneByOneAggr`).
1793    fn new(agg: &AggregateFunc, reverse: bool) -> Self;
1794    /// Pushes one input element into the aggregation.
1795    fn give(&mut self, d: &Datum);
1796    /// Returns the value of the aggregate computed on the given values so far.
1797    fn get_current_aggregate<'a>(&self, temp_storage: &'a RowArena) -> Datum<'a>;
1798}
1799
1800/// Naive implementation of [OneByOneAggr], suitable for stuff like const folding, but too slow for
1801/// rendering. This relies only on infrastructure available in `mz-expr`. It simply saves all the
1802/// given input, and calls the given [AggregateFunc]'s `eval` method when asked about the current
1803/// aggregate. (For Accumulable and Hierarchical aggregations, the rendering has more efficient
1804/// implementations, but for Basic aggregations even the rendering uses this naive implementation.)
1805#[derive(Debug)]
1806pub struct NaiveOneByOneAggr {
1807    agg: AggregateFunc,
1808    input: Vec<Row>,
1809    reverse: bool,
1810}
1811
1812impl OneByOneAggr for NaiveOneByOneAggr {
1813    fn new(agg: &AggregateFunc, reverse: bool) -> Self {
1814        NaiveOneByOneAggr {
1815            agg: agg.clone(),
1816            input: Vec::new(),
1817            reverse,
1818        }
1819    }
1820
1821    fn give(&mut self, d: &Datum) {
1822        let mut row = Row::default();
1823        row.packer().push(d);
1824        self.input.push(row);
1825    }
1826
1827    fn get_current_aggregate<'a>(&self, temp_storage: &'a RowArena) -> Datum<'a> {
1828        temp_storage.make_datum(|packer| {
1829            packer.push(if !self.reverse {
1830                self.agg.eval(
1831                    self.input.iter().map(|r| (r.unpack_first(), Diff::ONE)),
1832                    temp_storage,
1833                )
1834            } else {
1835                self.agg.eval(
1836                    self.input
1837                        .iter()
1838                        .rev()
1839                        .map(|r| (r.unpack_first(), Diff::ONE)),
1840                    temp_storage,
1841                )
1842            });
1843        })
1844    }
1845}
1846
1847/// Identify whether the given aggregate function is Lag or Lead, since they share
1848/// implementations.
1849#[derive(
1850    Clone,
1851    Debug,
1852    Eq,
1853    PartialEq,
1854    Ord,
1855    PartialOrd,
1856    Serialize,
1857    Deserialize,
1858    Hash,
1859    MzReflect
1860)]
1861pub enum LagLeadType {
1862    Lag,
1863    Lead,
1864}
1865
1866#[derive(
1867    Clone,
1868    Debug,
1869    Eq,
1870    PartialEq,
1871    Ord,
1872    PartialOrd,
1873    Serialize,
1874    Deserialize,
1875    Hash,
1876    MzReflect
1877)]
1878pub enum AggregateFunc {
1879    MaxNumeric,
1880    MaxInt16,
1881    MaxInt32,
1882    MaxInt64,
1883    MaxUInt16,
1884    MaxUInt32,
1885    MaxUInt64,
1886    MaxMzTimestamp,
1887    MaxFloat32,
1888    MaxFloat64,
1889    MaxBool,
1890    MaxString,
1891    MaxDate,
1892    MaxTimestamp,
1893    MaxTimestampTz,
1894    MaxInterval,
1895    MaxTime,
1896    MinNumeric,
1897    MinInt16,
1898    MinInt32,
1899    MinInt64,
1900    MinUInt16,
1901    MinUInt32,
1902    MinUInt64,
1903    MinMzTimestamp,
1904    MinFloat32,
1905    MinFloat64,
1906    MinBool,
1907    MinString,
1908    MinDate,
1909    MinTimestamp,
1910    MinTimestampTz,
1911    MinInterval,
1912    MinTime,
1913    SumInt16,
1914    SumInt32,
1915    SumInt64,
1916    SumUInt16,
1917    SumUInt32,
1918    SumUInt64,
1919    SumFloat32,
1920    SumFloat64,
1921    SumNumeric,
1922    Count,
1923    Any,
1924    All,
1925    /// Accumulates `Datum::List`s whose first element is a JSON-typed `Datum`s
1926    /// into a JSON list. The other elements are columns used by `order_by`.
1927    ///
1928    /// WARNING: Unlike the `jsonb_agg` function that is exposed by the SQL
1929    /// layer, this function filters out `Datum::Null`, for consistency with
1930    /// the other aggregate functions.
1931    JsonbAgg {
1932        order_by: Vec<ColumnOrder>,
1933    },
1934    /// Zips `Datum::List`s whose first element is a JSON-typed `Datum`s into a
1935    /// JSON map. The other elements are columns used by `order_by`.
1936    ///
1937    /// WARNING: Unlike the `jsonb_object_agg` function that is exposed by the SQL
1938    /// layer, this function filters out `Datum::Null`, for consistency with
1939    /// the other aggregate functions.
1940    JsonbObjectAgg {
1941        order_by: Vec<ColumnOrder>,
1942    },
1943    /// Zips a `Datum::List` whose first element is a `Datum::List` guaranteed
1944    /// to be non-empty and whose len % 2 == 0 into a `Datum::Map`. The other
1945    /// elements are columns used by `order_by`.
1946    MapAgg {
1947        order_by: Vec<ColumnOrder>,
1948        value_type: SqlScalarType,
1949    },
1950    /// Accumulates `Datum::Array`s of `SqlScalarType::Record` whose first element is a `Datum::Array`
1951    /// into a single `Datum::Array` (the remaining fields are used by `order_by`).
1952    ArrayConcat {
1953        order_by: Vec<ColumnOrder>,
1954    },
1955    /// Accumulates `Datum::List`s of `SqlScalarType::Record` whose first field is a `Datum::List`
1956    /// into a single `Datum::List` (the remaining fields are used by `order_by`).
1957    ListConcat {
1958        order_by: Vec<ColumnOrder>,
1959    },
1960    StringAgg {
1961        order_by: Vec<ColumnOrder>,
1962    },
1963    RowNumber {
1964        order_by: Vec<ColumnOrder>,
1965    },
1966    Rank {
1967        order_by: Vec<ColumnOrder>,
1968    },
1969    DenseRank {
1970        order_by: Vec<ColumnOrder>,
1971    },
1972    LagLead {
1973        order_by: Vec<ColumnOrder>,
1974        lag_lead: LagLeadType,
1975        ignore_nulls: bool,
1976    },
1977    FirstValue {
1978        order_by: Vec<ColumnOrder>,
1979        window_frame: WindowFrame,
1980    },
1981    LastValue {
1982        order_by: Vec<ColumnOrder>,
1983        window_frame: WindowFrame,
1984    },
1985    /// Several value window functions fused into one function, to amortize overheads.
1986    FusedValueWindowFunc {
1987        funcs: Vec<AggregateFunc>,
1988        /// Currently, all the fused functions must have the same `order_by`. (We can later
1989        /// eliminate this limitation.)
1990        order_by: Vec<ColumnOrder>,
1991    },
1992    WindowAggregate {
1993        wrapped_aggregate: Box<AggregateFunc>,
1994        order_by: Vec<ColumnOrder>,
1995        window_frame: WindowFrame,
1996    },
1997    FusedWindowAggregate {
1998        wrapped_aggregates: Vec<AggregateFunc>,
1999        order_by: Vec<ColumnOrder>,
2000        window_frame: WindowFrame,
2001    },
2002    /// Accumulates any number of `Datum::Dummy`s into `Datum::Dummy`.
2003    ///
2004    /// Useful for removing an expensive aggregation while maintaining the shape
2005    /// of a reduce operator.
2006    Dummy,
2007}
2008
2009/// Expands an iterator of `(datum, diff)` into one `datum` per unit of `diff`.
2010///
2011/// A non-positive `diff` contributes no copies. This is used by aggregates that
2012/// are sensitive to multiplicity (e.g. `sum`), to recover a flat datum stream
2013/// from the count-aware surface.
2014fn expand_counts<'a, I>(datums: I) -> impl Iterator<Item = Datum<'a>>
2015where
2016    I: IntoIterator<Item = (Datum<'a>, Diff)>,
2017{
2018    datums.into_iter().flat_map(|(datum, diff)| {
2019        let copies = usize::try_from(diff.into_inner()).unwrap_or(0);
2020        std::iter::repeat(datum).take(copies)
2021    })
2022}
2023
2024impl AggregateFunc {
2025    /// Whether this aggregate's result is independent of the multiplicity of its
2026    /// inputs (e.g. `min`/`max`/`any`/`all`).
2027    ///
2028    /// Such aggregates can ignore the `diff` of each input, evaluating over the
2029    /// distinct datums rather than expanding by count. This keeps idempotent
2030    /// reductions linear in the number of distinct inputs.
2031    fn ignores_multiplicity(&self) -> bool {
2032        use AggregateFunc::*;
2033        matches!(
2034            self,
2035            MaxNumeric
2036                | MaxInt16
2037                | MaxInt32
2038                | MaxInt64
2039                | MaxUInt16
2040                | MaxUInt32
2041                | MaxUInt64
2042                | MaxMzTimestamp
2043                | MaxFloat32
2044                | MaxFloat64
2045                | MaxBool
2046                | MaxString
2047                | MaxDate
2048                | MaxTimestamp
2049                | MaxTimestampTz
2050                | MaxInterval
2051                | MaxTime
2052                | MinNumeric
2053                | MinInt16
2054                | MinInt32
2055                | MinInt64
2056                | MinUInt16
2057                | MinUInt32
2058                | MinUInt64
2059                | MinMzTimestamp
2060                | MinFloat32
2061                | MinFloat64
2062                | MinBool
2063                | MinString
2064                | MinDate
2065                | MinTimestamp
2066                | MinTimestampTz
2067                | MinInterval
2068                | MinTime
2069                | Any
2070                | All
2071        )
2072    }
2073
2074    /// Evaluates the aggregate over an iterator of `(datum, diff)` pairs.
2075    ///
2076    /// Each aggregate consumes the multiplicity (`diff`) in whatever way is most
2077    /// efficient: `count` sums the diffs, multiplicity-insensitive aggregates
2078    /// (see `AggregateFunc::ignores_multiplicity`) ignore them, and everything
2079    /// else expands each datum into `diff` copies (see `expand_counts`).
2080    pub fn eval<'a, I>(&self, datums: I, temp_storage: &'a RowArena) -> Datum<'a>
2081    where
2082        I: IntoIterator<Item = (Datum<'a>, Diff)>,
2083    {
2084        // Accumulable aggregates consume multiplicity directly rather than
2085        // expanding each `(datum, diff)` into `diff` copies. The cases handled
2086        // here mirror the dataflow's accumulable reduction (`build_accumulable`
2087        // in `mz_compute::render::reduce`) so that constant folding produces the
2088        // same result the dataflow would. Signed integer sums are folded here;
2089        // unsigned sums are not, because their negative-accumulation case is a
2090        // query error in the dataflow that this `Datum`-returning path cannot
2091        // signal. Floats and numerics use bespoke fixed-point/wide-decimal
2092        // accumulators in the dataflow that `expand_counts` does not reproduce.
2093        match self {
2094            AggregateFunc::Count => count(datums),
2095            AggregateFunc::SumInt16 | AggregateFunc::SumInt32 => {
2096                // `finalize_accum` narrows these to `i64` with wrapping.
2097                sum_signed_int_counted(datums, |accum| {
2098                    #[allow(clippy::as_conversions)]
2099                    let narrowed = accum as i64;
2100                    Datum::Int64(narrowed)
2101                })
2102            }
2103            AggregateFunc::SumInt64 => sum_signed_int_counted(datums, Datum::from),
2104            _ if self.ignores_multiplicity() => {
2105                self.eval_datums(datums.into_iter().map(|(datum, _diff)| datum), temp_storage)
2106            }
2107            _ => self.eval_datums(expand_counts(datums), temp_storage),
2108        }
2109    }
2110
2111    /// Evaluates the aggregate over a flat iterator of datums, ignoring multiplicity.
2112    fn eval_datums<'a, I>(&self, datums: I, temp_storage: &'a RowArena) -> Datum<'a>
2113    where
2114        I: IntoIterator<Item = Datum<'a>>,
2115    {
2116        match self {
2117            AggregateFunc::MaxNumeric => {
2118                max_datum::<'a, I, OrderedDecimal<numeric::Numeric>>(datums)
2119            }
2120            AggregateFunc::MaxInt16 => max_datum::<'a, I, i16>(datums),
2121            AggregateFunc::MaxInt32 => max_datum::<'a, I, i32>(datums),
2122            AggregateFunc::MaxInt64 => max_datum::<'a, I, i64>(datums),
2123            AggregateFunc::MaxUInt16 => max_datum::<'a, I, u16>(datums),
2124            AggregateFunc::MaxUInt32 => max_datum::<'a, I, u32>(datums),
2125            AggregateFunc::MaxUInt64 => max_datum::<'a, I, u64>(datums),
2126            AggregateFunc::MaxMzTimestamp => max_datum::<'a, I, mz_repr::Timestamp>(datums),
2127            AggregateFunc::MaxFloat32 => max_datum::<'a, I, OrderedFloat<f32>>(datums),
2128            AggregateFunc::MaxFloat64 => max_datum::<'a, I, OrderedFloat<f64>>(datums),
2129            AggregateFunc::MaxBool => max_datum::<'a, I, bool>(datums),
2130            AggregateFunc::MaxString => max_string(datums),
2131            AggregateFunc::MaxDate => max_datum::<'a, I, Date>(datums),
2132            AggregateFunc::MaxTimestamp => {
2133                max_datum::<'a, I, CheckedTimestamp<NaiveDateTime>>(datums)
2134            }
2135            AggregateFunc::MaxTimestampTz => {
2136                max_datum::<'a, I, CheckedTimestamp<DateTime<Utc>>>(datums)
2137            }
2138            AggregateFunc::MaxInterval => max_datum::<'a, I, Interval>(datums),
2139            AggregateFunc::MaxTime => max_datum::<'a, I, NaiveTime>(datums),
2140            AggregateFunc::MinNumeric => {
2141                min_datum::<'a, I, OrderedDecimal<numeric::Numeric>>(datums)
2142            }
2143            AggregateFunc::MinInt16 => min_datum::<'a, I, i16>(datums),
2144            AggregateFunc::MinInt32 => min_datum::<'a, I, i32>(datums),
2145            AggregateFunc::MinInt64 => min_datum::<'a, I, i64>(datums),
2146            AggregateFunc::MinUInt16 => min_datum::<'a, I, u16>(datums),
2147            AggregateFunc::MinUInt32 => min_datum::<'a, I, u32>(datums),
2148            AggregateFunc::MinUInt64 => min_datum::<'a, I, u64>(datums),
2149            AggregateFunc::MinMzTimestamp => min_datum::<'a, I, mz_repr::Timestamp>(datums),
2150            AggregateFunc::MinFloat32 => min_datum::<'a, I, OrderedFloat<f32>>(datums),
2151            AggregateFunc::MinFloat64 => min_datum::<'a, I, OrderedFloat<f64>>(datums),
2152            AggregateFunc::MinBool => min_datum::<'a, I, bool>(datums),
2153            AggregateFunc::MinString => min_string(datums),
2154            AggregateFunc::MinDate => min_datum::<'a, I, Date>(datums),
2155            AggregateFunc::MinTimestamp => {
2156                min_datum::<'a, I, CheckedTimestamp<NaiveDateTime>>(datums)
2157            }
2158            AggregateFunc::MinTimestampTz => {
2159                min_datum::<'a, I, CheckedTimestamp<DateTime<Utc>>>(datums)
2160            }
2161            AggregateFunc::MinInterval => min_datum::<'a, I, Interval>(datums),
2162            AggregateFunc::MinTime => min_datum::<'a, I, NaiveTime>(datums),
2163            AggregateFunc::SumInt16 => sum_datum::<'a, I, i16, i64>(datums),
2164            AggregateFunc::SumInt32 => sum_datum::<'a, I, i32, i64>(datums),
2165            AggregateFunc::SumInt64 => sum_datum::<'a, I, i64, i128>(datums),
2166            AggregateFunc::SumUInt16 => sum_datum::<'a, I, u16, u64>(datums),
2167            AggregateFunc::SumUInt32 => sum_datum::<'a, I, u32, u64>(datums),
2168            AggregateFunc::SumUInt64 => sum_datum::<'a, I, u64, u128>(datums),
2169            AggregateFunc::SumFloat32 => sum_datum::<'a, I, f32, f32>(datums),
2170            AggregateFunc::SumFloat64 => sum_datum::<'a, I, f64, f64>(datums),
2171            AggregateFunc::SumNumeric => sum_numeric(datums),
2172            AggregateFunc::Count => unreachable!("Count is handled in `eval`"),
2173            AggregateFunc::Any => any(datums),
2174            AggregateFunc::All => all(datums),
2175            AggregateFunc::JsonbAgg { order_by } => jsonb_agg(datums, temp_storage, order_by),
2176            AggregateFunc::MapAgg { order_by, .. } | AggregateFunc::JsonbObjectAgg { order_by } => {
2177                dict_agg(datums, temp_storage, order_by)
2178            }
2179            AggregateFunc::ArrayConcat { order_by } => array_concat(datums, temp_storage, order_by),
2180            AggregateFunc::ListConcat { order_by } => list_concat(datums, temp_storage, order_by),
2181            AggregateFunc::StringAgg { order_by } => string_agg(datums, temp_storage, order_by),
2182            AggregateFunc::RowNumber { order_by } => row_number(datums, temp_storage, order_by),
2183            AggregateFunc::Rank { order_by } => rank(datums, temp_storage, order_by),
2184            AggregateFunc::DenseRank { order_by } => dense_rank(datums, temp_storage, order_by),
2185            AggregateFunc::LagLead {
2186                order_by,
2187                lag_lead: lag_lead_type,
2188                ignore_nulls,
2189            } => lag_lead(datums, temp_storage, order_by, lag_lead_type, ignore_nulls),
2190            AggregateFunc::FirstValue {
2191                order_by,
2192                window_frame,
2193            } => first_value(datums, temp_storage, order_by, window_frame),
2194            AggregateFunc::LastValue {
2195                order_by,
2196                window_frame,
2197            } => last_value(datums, temp_storage, order_by, window_frame),
2198            AggregateFunc::WindowAggregate {
2199                wrapped_aggregate,
2200                order_by,
2201                window_frame,
2202            } => window_aggr::<_, NaiveOneByOneAggr>(
2203                datums,
2204                temp_storage,
2205                wrapped_aggregate,
2206                order_by,
2207                window_frame,
2208            ),
2209            AggregateFunc::FusedValueWindowFunc { funcs, order_by } => {
2210                fused_value_window_func(datums, temp_storage, funcs, order_by)
2211            }
2212            AggregateFunc::FusedWindowAggregate {
2213                wrapped_aggregates,
2214                order_by,
2215                window_frame,
2216            } => fused_window_aggr::<_, NaiveOneByOneAggr>(
2217                datums,
2218                temp_storage,
2219                wrapped_aggregates,
2220                order_by,
2221                window_frame,
2222            ),
2223            AggregateFunc::Dummy => Datum::Dummy,
2224        }
2225    }
2226
2227    /// Like `eval`, but it's given a [OneByOneAggr]. If `self` is a `WindowAggregate`, then
2228    /// the given [OneByOneAggr] will be used to evaluate the wrapped aggregate inside the
2229    /// `WindowAggregate`. If `self` is not a `WindowAggregate`, then it simply calls `eval`.
2230    pub fn eval_with_fast_window_agg<'a, I, W>(
2231        &self,
2232        datums: I,
2233        temp_storage: &'a RowArena,
2234    ) -> Datum<'a>
2235    where
2236        I: IntoIterator<Item = (Datum<'a>, Diff)>,
2237        W: OneByOneAggr,
2238    {
2239        match self {
2240            AggregateFunc::WindowAggregate {
2241                wrapped_aggregate,
2242                order_by,
2243                window_frame,
2244            } => window_aggr::<_, W>(
2245                expand_counts(datums),
2246                temp_storage,
2247                wrapped_aggregate,
2248                order_by,
2249                window_frame,
2250            ),
2251            AggregateFunc::FusedWindowAggregate {
2252                wrapped_aggregates,
2253                order_by,
2254                window_frame,
2255            } => fused_window_aggr::<_, W>(
2256                expand_counts(datums),
2257                temp_storage,
2258                wrapped_aggregates,
2259                order_by,
2260                window_frame,
2261            ),
2262            _ => self.eval(datums, temp_storage),
2263        }
2264    }
2265
2266    pub fn eval_with_unnest_list<'a, I, W>(
2267        &self,
2268        datums: I,
2269        temp_storage: &'a RowArena,
2270    ) -> impl Iterator<Item = Datum<'a>>
2271    where
2272        I: IntoIterator<Item = (Datum<'a>, Diff)>,
2273        W: OneByOneAggr,
2274    {
2275        // TODO: Use `enum_dispatch` to construct a unified iterator instead of `collect_vec`.
2276        assert!(self.can_fuse_with_unnest_list());
2277        // Window functions are sensitive to multiplicity, so expand counts.
2278        let datums = expand_counts(datums);
2279        match self {
2280            AggregateFunc::RowNumber { order_by } => {
2281                row_number_no_list(datums, temp_storage, order_by).collect_vec()
2282            }
2283            AggregateFunc::Rank { order_by } => {
2284                rank_no_list(datums, temp_storage, order_by).collect_vec()
2285            }
2286            AggregateFunc::DenseRank { order_by } => {
2287                dense_rank_no_list(datums, temp_storage, order_by).collect_vec()
2288            }
2289            AggregateFunc::LagLead {
2290                order_by,
2291                lag_lead: lag_lead_type,
2292                ignore_nulls,
2293            } => lag_lead_no_list(datums, temp_storage, order_by, lag_lead_type, ignore_nulls)
2294                .collect_vec(),
2295            AggregateFunc::FirstValue {
2296                order_by,
2297                window_frame,
2298            } => first_value_no_list(datums, temp_storage, order_by, window_frame).collect_vec(),
2299            AggregateFunc::LastValue {
2300                order_by,
2301                window_frame,
2302            } => last_value_no_list(datums, temp_storage, order_by, window_frame).collect_vec(),
2303            AggregateFunc::FusedValueWindowFunc { funcs, order_by } => {
2304                fused_value_window_func_no_list(datums, temp_storage, funcs, order_by).collect_vec()
2305            }
2306            AggregateFunc::WindowAggregate {
2307                wrapped_aggregate,
2308                order_by,
2309                window_frame,
2310            } => window_aggr_no_list::<_, W>(
2311                datums,
2312                temp_storage,
2313                wrapped_aggregate,
2314                order_by,
2315                window_frame,
2316            )
2317            .collect_vec(),
2318            AggregateFunc::FusedWindowAggregate {
2319                wrapped_aggregates,
2320                order_by,
2321                window_frame,
2322            } => fused_window_aggr_no_list::<_, W>(
2323                datums,
2324                temp_storage,
2325                wrapped_aggregates,
2326                order_by,
2327                window_frame,
2328            )
2329            .collect_vec(),
2330            _ => unreachable!("asserted above that `can_fuse_with_unnest_list`"),
2331        }
2332        .into_iter()
2333    }
2334
2335    /// Returns the output of the aggregation function when applied on an empty
2336    /// input relation.
2337    pub fn default(&self) -> Datum<'static> {
2338        match self {
2339            AggregateFunc::Count => Datum::Int64(0),
2340            AggregateFunc::Any => Datum::False,
2341            AggregateFunc::All => Datum::True,
2342            AggregateFunc::Dummy => Datum::Dummy,
2343            _ => Datum::Null,
2344        }
2345    }
2346
2347    /// Returns a datum whose inclusion in the aggregation will not change its
2348    /// result.
2349    pub fn identity_datum(&self) -> Datum<'static> {
2350        match self {
2351            AggregateFunc::Any => Datum::False,
2352            AggregateFunc::All => Datum::True,
2353            AggregateFunc::Dummy => Datum::Dummy,
2354            AggregateFunc::ArrayConcat { .. } => Datum::empty_array(),
2355            AggregateFunc::ListConcat { .. } => Datum::empty_list(),
2356            AggregateFunc::RowNumber { .. }
2357            | AggregateFunc::Rank { .. }
2358            | AggregateFunc::DenseRank { .. }
2359            | AggregateFunc::LagLead { .. }
2360            | AggregateFunc::FirstValue { .. }
2361            | AggregateFunc::LastValue { .. }
2362            | AggregateFunc::WindowAggregate { .. }
2363            | AggregateFunc::FusedValueWindowFunc { .. }
2364            | AggregateFunc::FusedWindowAggregate { .. } => Datum::empty_list(),
2365            AggregateFunc::MaxNumeric
2366            | AggregateFunc::MaxInt16
2367            | AggregateFunc::MaxInt32
2368            | AggregateFunc::MaxInt64
2369            | AggregateFunc::MaxUInt16
2370            | AggregateFunc::MaxUInt32
2371            | AggregateFunc::MaxUInt64
2372            | AggregateFunc::MaxMzTimestamp
2373            | AggregateFunc::MaxFloat32
2374            | AggregateFunc::MaxFloat64
2375            | AggregateFunc::MaxBool
2376            | AggregateFunc::MaxString
2377            | AggregateFunc::MaxDate
2378            | AggregateFunc::MaxTimestamp
2379            | AggregateFunc::MaxTimestampTz
2380            | AggregateFunc::MaxInterval
2381            | AggregateFunc::MaxTime
2382            | AggregateFunc::MinNumeric
2383            | AggregateFunc::MinInt16
2384            | AggregateFunc::MinInt32
2385            | AggregateFunc::MinInt64
2386            | AggregateFunc::MinUInt16
2387            | AggregateFunc::MinUInt32
2388            | AggregateFunc::MinUInt64
2389            | AggregateFunc::MinMzTimestamp
2390            | AggregateFunc::MinFloat32
2391            | AggregateFunc::MinFloat64
2392            | AggregateFunc::MinBool
2393            | AggregateFunc::MinString
2394            | AggregateFunc::MinDate
2395            | AggregateFunc::MinTimestamp
2396            | AggregateFunc::MinTimestampTz
2397            | AggregateFunc::MinInterval
2398            | AggregateFunc::MinTime
2399            | AggregateFunc::SumInt16
2400            | AggregateFunc::SumInt32
2401            | AggregateFunc::SumInt64
2402            | AggregateFunc::SumUInt16
2403            | AggregateFunc::SumUInt32
2404            | AggregateFunc::SumUInt64
2405            | AggregateFunc::SumFloat32
2406            | AggregateFunc::SumFloat64
2407            | AggregateFunc::SumNumeric
2408            | AggregateFunc::Count
2409            | AggregateFunc::JsonbAgg { .. }
2410            | AggregateFunc::JsonbObjectAgg { .. }
2411            | AggregateFunc::MapAgg { .. }
2412            | AggregateFunc::StringAgg { .. } => Datum::Null,
2413        }
2414    }
2415
2416    pub fn can_fuse_with_unnest_list(&self) -> bool {
2417        match self {
2418            AggregateFunc::RowNumber { .. }
2419            | AggregateFunc::Rank { .. }
2420            | AggregateFunc::DenseRank { .. }
2421            | AggregateFunc::LagLead { .. }
2422            | AggregateFunc::FirstValue { .. }
2423            | AggregateFunc::LastValue { .. }
2424            | AggregateFunc::WindowAggregate { .. }
2425            | AggregateFunc::FusedValueWindowFunc { .. }
2426            | AggregateFunc::FusedWindowAggregate { .. } => true,
2427            AggregateFunc::ArrayConcat { .. }
2428            | AggregateFunc::ListConcat { .. }
2429            | AggregateFunc::Any
2430            | AggregateFunc::All
2431            | AggregateFunc::Dummy
2432            | AggregateFunc::MaxNumeric
2433            | AggregateFunc::MaxInt16
2434            | AggregateFunc::MaxInt32
2435            | AggregateFunc::MaxInt64
2436            | AggregateFunc::MaxUInt16
2437            | AggregateFunc::MaxUInt32
2438            | AggregateFunc::MaxUInt64
2439            | AggregateFunc::MaxMzTimestamp
2440            | AggregateFunc::MaxFloat32
2441            | AggregateFunc::MaxFloat64
2442            | AggregateFunc::MaxBool
2443            | AggregateFunc::MaxString
2444            | AggregateFunc::MaxDate
2445            | AggregateFunc::MaxTimestamp
2446            | AggregateFunc::MaxTimestampTz
2447            | AggregateFunc::MaxInterval
2448            | AggregateFunc::MaxTime
2449            | AggregateFunc::MinNumeric
2450            | AggregateFunc::MinInt16
2451            | AggregateFunc::MinInt32
2452            | AggregateFunc::MinInt64
2453            | AggregateFunc::MinUInt16
2454            | AggregateFunc::MinUInt32
2455            | AggregateFunc::MinUInt64
2456            | AggregateFunc::MinMzTimestamp
2457            | AggregateFunc::MinFloat32
2458            | AggregateFunc::MinFloat64
2459            | AggregateFunc::MinBool
2460            | AggregateFunc::MinString
2461            | AggregateFunc::MinDate
2462            | AggregateFunc::MinTimestamp
2463            | AggregateFunc::MinTimestampTz
2464            | AggregateFunc::MinInterval
2465            | AggregateFunc::MinTime
2466            | AggregateFunc::SumInt16
2467            | AggregateFunc::SumInt32
2468            | AggregateFunc::SumInt64
2469            | AggregateFunc::SumUInt16
2470            | AggregateFunc::SumUInt32
2471            | AggregateFunc::SumUInt64
2472            | AggregateFunc::SumFloat32
2473            | AggregateFunc::SumFloat64
2474            | AggregateFunc::SumNumeric
2475            | AggregateFunc::Count
2476            | AggregateFunc::JsonbAgg { .. }
2477            | AggregateFunc::JsonbObjectAgg { .. }
2478            | AggregateFunc::MapAgg { .. }
2479            | AggregateFunc::StringAgg { .. } => false,
2480        }
2481    }
2482
2483    /// The output column type for the result of an aggregation.
2484    ///
2485    /// The output column type also contains nullability information, which
2486    /// is (without further information) true for aggregations that are not
2487    /// counts.
2488    pub fn output_sql_type(&self, input_type: SqlColumnType) -> SqlColumnType {
2489        let scalar_type = match self {
2490            AggregateFunc::Count => SqlScalarType::Int64,
2491            AggregateFunc::Any => SqlScalarType::Bool,
2492            AggregateFunc::All => SqlScalarType::Bool,
2493            AggregateFunc::JsonbAgg { .. } => SqlScalarType::Jsonb,
2494            AggregateFunc::JsonbObjectAgg { .. } => SqlScalarType::Jsonb,
2495            AggregateFunc::SumInt16 => SqlScalarType::Int64,
2496            AggregateFunc::SumInt32 => SqlScalarType::Int64,
2497            AggregateFunc::SumInt64 => SqlScalarType::Numeric {
2498                max_scale: Some(NumericMaxScale::ZERO),
2499            },
2500            AggregateFunc::SumUInt16 => SqlScalarType::UInt64,
2501            AggregateFunc::SumUInt32 => SqlScalarType::UInt64,
2502            AggregateFunc::SumUInt64 => SqlScalarType::Numeric {
2503                max_scale: Some(NumericMaxScale::ZERO),
2504            },
2505            AggregateFunc::MapAgg { value_type, .. } => SqlScalarType::Map {
2506                value_type: Box::new(value_type.clone()),
2507                custom_id: None,
2508            },
2509            AggregateFunc::ArrayConcat { .. } | AggregateFunc::ListConcat { .. } => {
2510                match input_type.scalar_type {
2511                    // The input is wrapped in a Record if there's an ORDER BY, so extract it out.
2512                    SqlScalarType::Record { ref fields, .. } => fields[0].1.scalar_type.clone(),
2513                    _ => unreachable!(),
2514                }
2515            }
2516            AggregateFunc::StringAgg { .. } => SqlScalarType::String,
2517            AggregateFunc::RowNumber { .. } => {
2518                AggregateFunc::output_type_ranking_window_funcs(&input_type, "?row_number?")
2519            }
2520            AggregateFunc::Rank { .. } => {
2521                AggregateFunc::output_type_ranking_window_funcs(&input_type, "?rank?")
2522            }
2523            AggregateFunc::DenseRank { .. } => {
2524                AggregateFunc::output_type_ranking_window_funcs(&input_type, "?dense_rank?")
2525            }
2526            AggregateFunc::LagLead { lag_lead: lag_lead_type, .. } => {
2527                // The input type for Lag is ((OriginalRow, EncodedArgs), OrderByExprs...)
2528                let fields = input_type.scalar_type.unwrap_record_element_type();
2529                let original_row_type = fields[0].unwrap_record_element_type()[0]
2530                    .clone()
2531                    .nullable(false);
2532                let encoded_args = fields[0].unwrap_record_element_type()[1];
2533                let output_type_inner =
2534                    Self::lag_lead_output_type_inner_from_encoded_args(encoded_args);
2535                let column_name = Self::lag_lead_result_column_name(lag_lead_type);
2536
2537                SqlScalarType::List {
2538                    element_type: Box::new(SqlScalarType::Record {
2539                        fields: [
2540                            (column_name, output_type_inner),
2541                            (ColumnName::from("?orig_row?"), original_row_type),
2542                        ].into(),
2543                        custom_id: None,
2544                    }),
2545                    custom_id: None,
2546                }
2547            }
2548            AggregateFunc::FirstValue { .. } => {
2549                // The input type for FirstValue is ((OriginalRow, Arg), OrderByExprs...)
2550                let fields = input_type.scalar_type.unwrap_record_element_type();
2551                let original_row_type = fields[0].unwrap_record_element_type()[0]
2552                    .clone()
2553                    .nullable(false);
2554                let value_type = fields[0].unwrap_record_element_type()[1]
2555                    .clone()
2556                    .nullable(true); // null when the partition is empty
2557
2558                SqlScalarType::List {
2559                    element_type: Box::new(SqlScalarType::Record {
2560                        fields: [
2561                            (ColumnName::from("?first_value?"), value_type),
2562                            (ColumnName::from("?orig_row?"), original_row_type),
2563                        ].into(),
2564                        custom_id: None,
2565                    }),
2566                    custom_id: None,
2567                }
2568            }
2569            AggregateFunc::LastValue { .. } => {
2570                // The input type for LastValue is ((OriginalRow, Arg), OrderByExprs...)
2571                let fields = input_type.scalar_type.unwrap_record_element_type();
2572                let original_row_type = fields[0].unwrap_record_element_type()[0]
2573                    .clone()
2574                    .nullable(false);
2575                let value_type = fields[0].unwrap_record_element_type()[1]
2576                    .clone()
2577                    .nullable(true); // null when the partition is empty
2578
2579                SqlScalarType::List {
2580                    element_type: Box::new(SqlScalarType::Record {
2581                        fields: [
2582                            (ColumnName::from("?last_value?"), value_type),
2583                            (ColumnName::from("?orig_row?"), original_row_type),
2584                        ].into(),
2585                        custom_id: None,
2586                    }),
2587                    custom_id: None,
2588                }
2589            }
2590            AggregateFunc::WindowAggregate {
2591                wrapped_aggregate, ..
2592            } => {
2593                // The input type for a window aggregate is ((OriginalRow, Arg), OrderByExprs...)
2594                let fields = input_type.scalar_type.unwrap_record_element_type();
2595                let original_row_type = fields[0].unwrap_record_element_type()[0]
2596                    .clone()
2597                    .nullable(false);
2598                let arg_type = fields[0].unwrap_record_element_type()[1]
2599                    .clone()
2600                    .nullable(true);
2601                let wrapped_aggr_out_type = wrapped_aggregate.output_sql_type(arg_type);
2602
2603                SqlScalarType::List {
2604                    element_type: Box::new(SqlScalarType::Record {
2605                        fields: [
2606                            (ColumnName::from("?window_agg?"), wrapped_aggr_out_type),
2607                            (ColumnName::from("?orig_row?"), original_row_type),
2608                        ].into(),
2609                        custom_id: None,
2610                    }),
2611                    custom_id: None,
2612                }
2613            }
2614            AggregateFunc::FusedWindowAggregate {
2615                wrapped_aggregates, ..
2616            } => {
2617                // The input type for a fused window aggregate is ((OriginalRow, Args), OrderByExprs...)
2618                // where `Args` is a record.
2619                let fields = input_type.scalar_type.unwrap_record_element_type();
2620                let original_row_type = fields[0].unwrap_record_element_type()[0]
2621                    .clone()
2622                    .nullable(false);
2623                let args_type = fields[0].unwrap_record_element_type()[1];
2624                let arg_types = args_type.unwrap_record_element_type();
2625                let out_fields = arg_types.iter().zip_eq(wrapped_aggregates).map(
2626                    |(arg_type, wrapped_agg)| {
2627                    (
2628                        ColumnName::from(wrapped_agg.name()),
2629                        wrapped_agg.output_sql_type((**arg_type).clone().nullable(true)),
2630                    )
2631                }).collect_vec();
2632
2633                SqlScalarType::List {
2634                    element_type: Box::new(SqlScalarType::Record {
2635                        fields: [
2636                            (ColumnName::from("?fused_window_agg?"), SqlScalarType::Record {
2637                                fields: out_fields.into(),
2638                                custom_id: None,
2639                            }.nullable(false)),
2640                            (ColumnName::from("?orig_row?"), original_row_type),
2641                        ].into(),
2642                        custom_id: None,
2643                    }),
2644                    custom_id: None,
2645                }
2646            }
2647            AggregateFunc::FusedValueWindowFunc { funcs, order_by: _ } => {
2648                // The input type is ((OriginalRow, EncodedArgs), OrderByExprs...)
2649                // where EncodedArgs is a record, where each element is the argument to one of the
2650                // function calls that got fused. This is a record for lag/lead, and a simple type
2651                // for first_value/last_value.
2652                let fields = input_type.scalar_type.unwrap_record_element_type();
2653                let original_row_type = fields[0].unwrap_record_element_type()[0]
2654                    .clone()
2655                    .nullable(false);
2656                let encoded_args_type = fields[0]
2657                    .unwrap_record_element_type()[1]
2658                    .unwrap_record_element_type();
2659
2660                SqlScalarType::List {
2661                    element_type: Box::new(SqlScalarType::Record {
2662                        fields: [
2663                            (
2664                                ColumnName::from("?fused_value_window_func?"),
2665                                SqlScalarType::Record {
2666                                fields: encoded_args_type.into_iter().zip_eq(funcs).map(
2667                                    |(arg_type, func)| {
2668                                    match func {
2669                                        AggregateFunc::LagLead {
2670                                            lag_lead: lag_lead_type, ..
2671                                        } => {
2672                                            let name = Self::lag_lead_result_column_name(
2673                                                lag_lead_type,
2674                                            );
2675                                            let ty = Self
2676                                                ::lag_lead_output_type_inner_from_encoded_args(
2677                                                    arg_type,
2678                                                );
2679                                            (name, ty)
2680                                        },
2681                                        AggregateFunc::FirstValue { .. } => {
2682                                            (
2683                                                ColumnName::from("?first_value?"),
2684                                                arg_type.clone().nullable(true),
2685                                            )
2686                                        }
2687                                        AggregateFunc::LastValue { .. } => {
2688                                            (
2689                                                ColumnName::from("?last_value?"),
2690                                                arg_type.clone().nullable(true),
2691                                            )
2692                                        }
2693                                        _ => panic!("FusedValueWindowFunc has an unknown function"),
2694                                    }
2695                                }).collect(),
2696                                custom_id: None,
2697                            }.nullable(false)),
2698                            (ColumnName::from("?orig_row?"), original_row_type),
2699                        ].into(),
2700                        custom_id: None,
2701                    }),
2702                    custom_id: None,
2703                }
2704            }
2705            AggregateFunc::Dummy
2706            | AggregateFunc::MaxNumeric
2707            | AggregateFunc::MaxInt16
2708            | AggregateFunc::MaxInt32
2709            | AggregateFunc::MaxInt64
2710            | AggregateFunc::MaxUInt16
2711            | AggregateFunc::MaxUInt32
2712            | AggregateFunc::MaxUInt64
2713            | AggregateFunc::MaxMzTimestamp
2714            | AggregateFunc::MaxFloat32
2715            | AggregateFunc::MaxFloat64
2716            | AggregateFunc::MaxBool
2717            // Note AggregateFunc::MaxString, MinString rely on returning input
2718            // type as output type to support the proper return type for
2719            // character input.
2720            | AggregateFunc::MaxString
2721            | AggregateFunc::MaxDate
2722            | AggregateFunc::MaxTimestamp
2723            | AggregateFunc::MaxTimestampTz
2724            | AggregateFunc::MaxInterval
2725            | AggregateFunc::MaxTime
2726            | AggregateFunc::MinNumeric
2727            | AggregateFunc::MinInt16
2728            | AggregateFunc::MinInt32
2729            | AggregateFunc::MinInt64
2730            | AggregateFunc::MinUInt16
2731            | AggregateFunc::MinUInt32
2732            | AggregateFunc::MinUInt64
2733            | AggregateFunc::MinMzTimestamp
2734            | AggregateFunc::MinFloat32
2735            | AggregateFunc::MinFloat64
2736            | AggregateFunc::MinBool
2737            | AggregateFunc::MinString
2738            | AggregateFunc::MinDate
2739            | AggregateFunc::MinTimestamp
2740            | AggregateFunc::MinTimestampTz
2741            | AggregateFunc::MinInterval
2742            | AggregateFunc::MinTime
2743            | AggregateFunc::SumFloat32
2744            | AggregateFunc::SumFloat64
2745            | AggregateFunc::SumNumeric => input_type.scalar_type.clone(),
2746        };
2747        // Count never produces null, and other aggregations only produce
2748        // null in the presence of null inputs.
2749        let nullable = match self {
2750            AggregateFunc::Count => false,
2751            // Use the nullability of the underlying column being aggregated, not the Records wrapping it
2752            AggregateFunc::StringAgg { .. } => match input_type.scalar_type {
2753                // The outer Record wraps the input in the first position, and any ORDER BY expressions afterwards
2754                SqlScalarType::Record { fields, .. } => match &fields[0].1.scalar_type {
2755                    // The inner Record is a (value, separator) tuple
2756                    SqlScalarType::Record { fields, .. } => fields[0].1.nullable,
2757                    _ => unreachable!(),
2758                },
2759                _ => unreachable!(),
2760            },
2761            _ => input_type.nullable,
2762        };
2763        scalar_type.nullable(nullable)
2764    }
2765
2766    /// Computes the representation type of this aggregate function.
2767    ///
2768    /// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type.
2769    pub fn output_type(&self, input_type: ReprColumnType) -> ReprColumnType {
2770        ReprColumnType::from(&self.output_sql_type(SqlColumnType::from_repr(&input_type)))
2771    }
2772
2773    /// Compute output type for ROW_NUMBER, RANK, DENSE_RANK
2774    fn output_type_ranking_window_funcs(
2775        input_type: &SqlColumnType,
2776        col_name: &str,
2777    ) -> SqlScalarType {
2778        match input_type.scalar_type {
2779            SqlScalarType::Record { ref fields, .. } => SqlScalarType::List {
2780                element_type: Box::new(SqlScalarType::Record {
2781                    fields: [
2782                        (
2783                            ColumnName::from(col_name),
2784                            SqlScalarType::Int64.nullable(false),
2785                        ),
2786                        (ColumnName::from("?orig_row?"), {
2787                            let inner = match &fields[0].1.scalar_type {
2788                                SqlScalarType::List { element_type, .. } => element_type.clone(),
2789                                _ => unreachable!(),
2790                            };
2791                            inner.nullable(false)
2792                        }),
2793                    ]
2794                    .into(),
2795                    custom_id: None,
2796                }),
2797                custom_id: None,
2798            },
2799            _ => unreachable!(),
2800        }
2801    }
2802
2803    /// Given the `EncodedArgs` part of `((OriginalRow, EncodedArgs), OrderByExprs...)`,
2804    /// this computes the type of the first field of the output type. (The first field is the
2805    /// real result, the rest is the original row.)
2806    fn lag_lead_output_type_inner_from_encoded_args(
2807        encoded_args_type: &SqlScalarType,
2808    ) -> SqlColumnType {
2809        // lag/lead have 3 arguments, and the output type is
2810        // the same as the first of these, but always nullable. (It's null when the
2811        // lag/lead computation reaches over the bounds of the window partition.)
2812        encoded_args_type.unwrap_record_element_type()[0]
2813            .clone()
2814            .nullable(true)
2815    }
2816
2817    fn lag_lead_result_column_name(lag_lead_type: &LagLeadType) -> ColumnName {
2818        ColumnName::from(match lag_lead_type {
2819            LagLeadType::Lag => "?lag?",
2820            LagLeadType::Lead => "?lead?",
2821        })
2822    }
2823
2824    /// Returns true if the non-null constraint on the aggregation can be
2825    /// converted into a non-null constraint on its parameter expression, ie.
2826    /// whether the result of the aggregation is null if all the input values
2827    /// are null.
2828    pub fn propagates_nonnull_constraint(&self) -> bool {
2829        match self {
2830            AggregateFunc::MaxNumeric
2831            | AggregateFunc::MaxInt16
2832            | AggregateFunc::MaxInt32
2833            | AggregateFunc::MaxInt64
2834            | AggregateFunc::MaxUInt16
2835            | AggregateFunc::MaxUInt32
2836            | AggregateFunc::MaxUInt64
2837            | AggregateFunc::MaxMzTimestamp
2838            | AggregateFunc::MaxFloat32
2839            | AggregateFunc::MaxFloat64
2840            | AggregateFunc::MaxBool
2841            | AggregateFunc::MaxString
2842            | AggregateFunc::MaxDate
2843            | AggregateFunc::MaxTimestamp
2844            | AggregateFunc::MaxTimestampTz
2845            | AggregateFunc::MaxInterval
2846            | AggregateFunc::MaxTime
2847            | AggregateFunc::MinNumeric
2848            | AggregateFunc::MinInt16
2849            | AggregateFunc::MinInt32
2850            | AggregateFunc::MinInt64
2851            | AggregateFunc::MinUInt16
2852            | AggregateFunc::MinUInt32
2853            | AggregateFunc::MinUInt64
2854            | AggregateFunc::MinMzTimestamp
2855            | AggregateFunc::MinFloat32
2856            | AggregateFunc::MinFloat64
2857            | AggregateFunc::MinBool
2858            | AggregateFunc::MinString
2859            | AggregateFunc::MinDate
2860            | AggregateFunc::MinTimestamp
2861            | AggregateFunc::MinTimestampTz
2862            | AggregateFunc::MinInterval
2863            | AggregateFunc::MinTime
2864            | AggregateFunc::SumInt16
2865            | AggregateFunc::SumInt32
2866            | AggregateFunc::SumInt64
2867            | AggregateFunc::SumUInt16
2868            | AggregateFunc::SumUInt32
2869            | AggregateFunc::SumUInt64
2870            | AggregateFunc::SumFloat32
2871            | AggregateFunc::SumFloat64
2872            | AggregateFunc::SumNumeric
2873            | AggregateFunc::StringAgg { .. } => true,
2874            // Count is never null
2875            AggregateFunc::Count
2876            | AggregateFunc::Any
2877            | AggregateFunc::All
2878            | AggregateFunc::JsonbAgg { .. }
2879            | AggregateFunc::JsonbObjectAgg { .. }
2880            | AggregateFunc::MapAgg { .. }
2881            | AggregateFunc::ArrayConcat { .. }
2882            | AggregateFunc::ListConcat { .. }
2883            | AggregateFunc::RowNumber { .. }
2884            | AggregateFunc::Rank { .. }
2885            | AggregateFunc::DenseRank { .. }
2886            | AggregateFunc::LagLead { .. }
2887            | AggregateFunc::FirstValue { .. }
2888            | AggregateFunc::LastValue { .. }
2889            | AggregateFunc::FusedValueWindowFunc { .. }
2890            | AggregateFunc::WindowAggregate { .. }
2891            | AggregateFunc::FusedWindowAggregate { .. }
2892            | AggregateFunc::Dummy => false,
2893        }
2894    }
2895}
2896
2897fn jsonb_each<'a>(a: Datum<'a>) -> impl Iterator<Item = (Row, Diff)> + 'a {
2898    // First produce a map, so that a common iterator can be returned.
2899    let map = match a {
2900        Datum::Map(dict) => dict,
2901        _ => mz_repr::DatumMap::empty(),
2902    };
2903
2904    map.iter()
2905        .map(move |(k, v)| (Row::pack_slice(&[Datum::String(k), v]), Diff::ONE))
2906}
2907
2908fn jsonb_each_stringify<'a>(
2909    a: Datum<'a>,
2910    temp_storage: &'a RowArena,
2911) -> impl Iterator<Item = (Row, Diff)> + 'a {
2912    // First produce a map, so that a common iterator can be returned.
2913    let map = match a {
2914        Datum::Map(dict) => dict,
2915        _ => mz_repr::DatumMap::empty(),
2916    };
2917
2918    map.iter().map(move |(k, mut v)| {
2919        v = jsonb_stringify(v, temp_storage)
2920            .map(Datum::String)
2921            .unwrap_or(Datum::Null);
2922        (Row::pack_slice(&[Datum::String(k), v]), Diff::ONE)
2923    })
2924}
2925
2926fn jsonb_object_keys<'a>(a: Datum<'a>) -> impl Iterator<Item = (Row, Diff)> + 'a {
2927    let map = match a {
2928        Datum::Map(dict) => dict,
2929        _ => mz_repr::DatumMap::empty(),
2930    };
2931
2932    map.iter()
2933        .map(move |(k, _)| (Row::pack_slice(&[Datum::String(k)]), Diff::ONE))
2934}
2935
2936fn jsonb_array_elements<'a>(a: Datum<'a>) -> impl Iterator<Item = (Row, Diff)> + 'a {
2937    let list = match a {
2938        Datum::List(list) => list,
2939        _ => mz_repr::DatumList::empty(),
2940    };
2941    list.iter().map(move |e| (Row::pack_slice(&[e]), Diff::ONE))
2942}
2943
2944fn jsonb_array_elements_stringify<'a>(
2945    a: Datum<'a>,
2946    temp_storage: &'a RowArena,
2947) -> impl Iterator<Item = (Row, Diff)> + 'a {
2948    let list = match a {
2949        Datum::List(list) => list,
2950        _ => mz_repr::DatumList::empty(),
2951    };
2952    list.iter().map(move |mut e| {
2953        e = jsonb_stringify(e, temp_storage)
2954            .map(Datum::String)
2955            .unwrap_or(Datum::Null);
2956        (Row::pack_slice(&[e]), Diff::ONE)
2957    })
2958}
2959
2960fn regexp_extract(a: Datum, r: &AnalyzedRegex) -> Option<(Row, Diff)> {
2961    let r = r.inner();
2962    let a = a.unwrap_str();
2963    let captures = r.captures(a)?;
2964    let datums = captures
2965        .iter()
2966        .skip(1)
2967        .map(|m| Datum::from(m.map(|m| m.as_str())));
2968    Some((Row::pack(datums), Diff::ONE))
2969}
2970
2971fn regexp_matches<'a>(
2972    exprs: &[Datum<'a>],
2973) -> Result<impl Iterator<Item = (Row, Diff)> + 'a, EvalError> {
2974    // There are only two acceptable ways to call this function:
2975    // 1. regexp_matches(string, regex)
2976    // 2. regexp_matches(string, regex, flag)
2977    assert!(exprs.len() == 2 || exprs.len() == 3);
2978    let a = exprs[0].unwrap_str();
2979    let r = exprs[1].unwrap_str();
2980
2981    let (regex, opts) = if exprs.len() == 3 {
2982        let flag = exprs[2].unwrap_str();
2983        let opts = AnalyzedRegexOpts::from_str(flag)?;
2984        (AnalyzedRegex::new(r, opts)?, opts)
2985    } else {
2986        let opts = AnalyzedRegexOpts::default();
2987        (AnalyzedRegex::new(r, opts)?, opts)
2988    };
2989
2990    let regex = regex.inner().clone();
2991
2992    let iter = regex.captures_iter(a).map(move |captures| {
2993        let matches = captures
2994            .iter()
2995            // The first match is the *entire* match, we want the capture groups by themselves.
2996            .skip(1)
2997            .map(|m| Datum::from(m.map(|m| m.as_str())))
2998            .collect::<Vec<_>>();
2999
3000        let mut binding = SharedRow::get();
3001        let mut packer = binding.packer();
3002
3003        let dimension = ArrayDimension {
3004            lower_bound: 1,
3005            length: matches.len(),
3006        };
3007        packer
3008            .try_push_array(&[dimension], matches)
3009            .expect("generated dimensions above");
3010
3011        (binding.clone(), Diff::ONE)
3012    });
3013
3014    // This is slightly unfortunate, but we need to collect the captures into a
3015    // Vec before we can yield them, because we can't return a iter with a
3016    // reference to the local `regex` variable.
3017    // We attempt to minimize the cost of this by using a SmallVec.
3018    let out = iter.collect::<SmallVec<[_; 3]>>();
3019
3020    if opts.global {
3021        Ok(Either::Left(out.into_iter()))
3022    } else {
3023        Ok(Either::Right(out.into_iter().take(1)))
3024    }
3025}
3026
3027fn generate_series<N>(
3028    start: N,
3029    stop: N,
3030    step: N,
3031) -> Result<impl Iterator<Item = (Row, Diff)>, EvalError>
3032where
3033    N: Integer + Signed + CheckedAdd + Clone,
3034    Datum<'static>: From<N>,
3035{
3036    if step == N::zero() {
3037        return Err(EvalError::InvalidParameterValue(
3038            "step size cannot equal zero".into(),
3039        ));
3040    }
3041    Ok(num::range_step_inclusive(start, stop, step)
3042        .map(move |i| (Row::pack_slice(&[Datum::from(i)]), Diff::ONE)))
3043}
3044
/// Like
/// [`num::range_step_inclusive`](https://github.com/rust-num/num-iter/blob/ddb14c1e796d401014c6c7a727de61d8109ad986/src/lib.rs#L279),
/// but for our timestamp types using [`Interval`] for `step`.
#[derive(Clone)]
pub struct TimestampRangeStepInclusive<T> {
    // The next value to yield, provided it is still within range.
    state: CheckedTimestamp<T>,
    // The inclusive end of the range.
    stop: CheckedTimestamp<T>,
    // The amount added to `state` after each yielded value.
    step: Interval,
    // When true, iteration continues while `state >= stop`; otherwise while
    // `state <= stop`.
    rev: bool,
    // Set once `state` can no longer be advanced; ends the iteration.
    done: bool,
}
3056
3057impl<T: TimestampLike> Iterator for TimestampRangeStepInclusive<T> {
3058    type Item = CheckedTimestamp<T>;
3059
3060    #[inline]
3061    fn next(&mut self) -> Option<CheckedTimestamp<T>> {
3062        if !self.done
3063            && ((self.rev && self.state >= self.stop) || (!self.rev && self.state <= self.stop))
3064        {
3065            let result = self.state.clone();
3066            match add_timestamp_months(self.state.deref(), self.step.months) {
3067                Ok(state) => match state.checked_add_signed(self.step.duration_as_chrono()) {
3068                    Some(v) => match CheckedTimestamp::from_timestamplike(v) {
3069                        Ok(v) => self.state = v,
3070                        Err(_) => self.done = true,
3071                    },
3072                    None => self.done = true,
3073                },
3074                Err(..) => {
3075                    self.done = true;
3076                }
3077            }
3078
3079            Some(result)
3080        } else {
3081            None
3082        }
3083    }
3084}
3085
3086fn generate_series_ts<T: TimestampLike>(
3087    start: CheckedTimestamp<T>,
3088    stop: CheckedTimestamp<T>,
3089    step: Interval,
3090    conv: fn(CheckedTimestamp<T>) -> Datum<'static>,
3091) -> Result<impl Iterator<Item = (Row, Diff)>, EvalError> {
3092    let normalized_step = step.as_microseconds();
3093    if normalized_step == 0 {
3094        return Err(EvalError::InvalidParameterValue(
3095            "step size cannot equal zero".into(),
3096        ));
3097    }
3098    let rev = normalized_step < 0;
3099
3100    let trsi = TimestampRangeStepInclusive {
3101        state: start,
3102        stop,
3103        step,
3104        rev,
3105        done: false,
3106    };
3107
3108    Ok(trsi.map(move |i| (Row::pack_slice(&[conv(i)]), Diff::ONE)))
3109}
3110
3111fn generate_subscripts_array(
3112    a: Datum,
3113    dim: i32,
3114) -> Result<Box<dyn Iterator<Item = (Row, Diff)>>, EvalError> {
3115    if dim <= 0 {
3116        return Ok(Box::new(iter::empty()));
3117    }
3118
3119    match a.unwrap_array().dims().into_iter().nth(
3120        (dim - 1)
3121            .try_into()
3122            .map_err(|_| EvalError::Int32OutOfRange((dim - 1).to_string().into()))?,
3123    ) {
3124        Some(requested_dim) => {
3125            let lower_bound: i32 = requested_dim.lower_bound.try_into().map_err(|_| {
3126                EvalError::Int32OutOfRange(requested_dim.lower_bound.to_string().into())
3127            })?;
3128            // The subscripts run from the lower bound to the upper bound,
3129            // inclusive. The upper bound is `lower_bound + length - 1`.
3130            let length: i32 = requested_dim
3131                .length
3132                .try_into()
3133                .map_err(|_| EvalError::Int32OutOfRange(requested_dim.length.to_string().into()))?;
3134            let upper_bound = lower_bound.checked_add(length - 1).ok_or_else(|| {
3135                EvalError::Int32OutOfRange(requested_dim.length.to_string().into())
3136            })?;
3137            Ok(Box::new(generate_series::<i32>(
3138                lower_bound,
3139                upper_bound,
3140                1,
3141            )?))
3142        }
3143        None => Ok(Box::new(iter::empty())),
3144    }
3145}
3146
3147fn unnest_array<'a>(a: Datum<'a>) -> impl Iterator<Item = (Row, Diff)> + 'a {
3148    a.unwrap_array()
3149        .elements()
3150        .iter()
3151        .map(move |e| (Row::pack_slice(&[e]), Diff::ONE))
3152}
3153
3154fn unnest_list<'a>(a: Datum<'a>) -> impl Iterator<Item = (Row, Diff)> + 'a {
3155    a.unwrap_list()
3156        .iter()
3157        .map(move |e| (Row::pack_slice(&[e]), Diff::ONE))
3158}
3159
3160fn unnest_map<'a>(a: Datum<'a>) -> impl Iterator<Item = (Row, Diff)> + 'a {
3161    a.unwrap_map()
3162        .iter()
3163        .map(move |(k, v)| (Row::pack_slice(&[Datum::from(k), v]), Diff::ONE))
3164}
3165
3166impl AggregateFunc {
3167    /// The base function name without the `~[...]` suffix used when rendering
3168    /// variants that represent a parameterized function family.
3169    pub fn name(&self) -> &'static str {
3170        match self {
3171            Self::MaxNumeric => "max",
3172            Self::MaxInt16 => "max",
3173            Self::MaxInt32 => "max",
3174            Self::MaxInt64 => "max",
3175            Self::MaxUInt16 => "max",
3176            Self::MaxUInt32 => "max",
3177            Self::MaxUInt64 => "max",
3178            Self::MaxMzTimestamp => "max",
3179            Self::MaxFloat32 => "max",
3180            Self::MaxFloat64 => "max",
3181            Self::MaxBool => "max",
3182            Self::MaxString => "max",
3183            Self::MaxDate => "max",
3184            Self::MaxTimestamp => "max",
3185            Self::MaxTimestampTz => "max",
3186            Self::MaxInterval => "max",
3187            Self::MaxTime => "max",
3188            Self::MinNumeric => "min",
3189            Self::MinInt16 => "min",
3190            Self::MinInt32 => "min",
3191            Self::MinInt64 => "min",
3192            Self::MinUInt16 => "min",
3193            Self::MinUInt32 => "min",
3194            Self::MinUInt64 => "min",
3195            Self::MinMzTimestamp => "min",
3196            Self::MinFloat32 => "min",
3197            Self::MinFloat64 => "min",
3198            Self::MinBool => "min",
3199            Self::MinString => "min",
3200            Self::MinDate => "min",
3201            Self::MinTimestamp => "min",
3202            Self::MinTimestampTz => "min",
3203            Self::MinInterval => "min",
3204            Self::MinTime => "min",
3205            Self::SumInt16 => "sum",
3206            Self::SumInt32 => "sum",
3207            Self::SumInt64 => "sum",
3208            Self::SumUInt16 => "sum",
3209            Self::SumUInt32 => "sum",
3210            Self::SumUInt64 => "sum",
3211            Self::SumFloat32 => "sum",
3212            Self::SumFloat64 => "sum",
3213            Self::SumNumeric => "sum",
3214            Self::Count => "count",
3215            Self::Any => "any",
3216            Self::All => "all",
3217            Self::JsonbAgg { .. } => "jsonb_agg",
3218            Self::JsonbObjectAgg { .. } => "jsonb_object_agg",
3219            Self::MapAgg { .. } => "map_agg",
3220            Self::ArrayConcat { .. } => "array_agg",
3221            Self::ListConcat { .. } => "list_agg",
3222            Self::StringAgg { .. } => "string_agg",
3223            Self::RowNumber { .. } => "row_number",
3224            Self::Rank { .. } => "rank",
3225            Self::DenseRank { .. } => "dense_rank",
3226            Self::LagLead {
3227                lag_lead: LagLeadType::Lag,
3228                ..
3229            } => "lag",
3230            Self::LagLead {
3231                lag_lead: LagLeadType::Lead,
3232                ..
3233            } => "lead",
3234            Self::FirstValue { .. } => "first_value",
3235            Self::LastValue { .. } => "last_value",
3236            Self::WindowAggregate { .. } => "window_agg",
3237            Self::FusedValueWindowFunc { .. } => "fused_value_window_func",
3238            Self::FusedWindowAggregate { .. } => "fused_window_agg",
3239            Self::Dummy => "dummy",
3240        }
3241    }
3242}
3243
3244impl<'a, M> fmt::Display for HumanizedExpr<'a, AggregateFunc, M>
3245where
3246    M: HumanizerMode,
3247{
3248    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3249        use AggregateFunc::*;
3250        let name = self.expr.name();
3251        match self.expr {
3252            JsonbAgg { order_by }
3253            | JsonbObjectAgg { order_by }
3254            | MapAgg { order_by, .. }
3255            | ArrayConcat { order_by }
3256            | ListConcat { order_by }
3257            | StringAgg { order_by }
3258            | RowNumber { order_by }
3259            | Rank { order_by }
3260            | DenseRank { order_by } => {
3261                let order_by = order_by.iter().map(|col| self.child(col));
3262                write!(f, "{}[order_by=[{}]]", name, separated(", ", order_by))
3263            }
3264            LagLead {
3265                lag_lead: _,
3266                ignore_nulls,
3267                order_by,
3268            } => {
3269                let order_by = order_by.iter().map(|col| self.child(col));
3270                f.write_str(name)?;
3271                f.write_str("[")?;
3272                if *ignore_nulls {
3273                    f.write_str("ignore_nulls=true, ")?;
3274                }
3275                write!(f, "order_by=[{}]", separated(", ", order_by))?;
3276                f.write_str("]")
3277            }
3278            FirstValue {
3279                order_by,
3280                window_frame,
3281            } => {
3282                let order_by = order_by.iter().map(|col| self.child(col));
3283                f.write_str(name)?;
3284                f.write_str("[")?;
3285                write!(f, "order_by=[{}]", separated(", ", order_by))?;
3286                if *window_frame != WindowFrame::default() {
3287                    write!(f, " {}", window_frame)?;
3288                }
3289                f.write_str("]")
3290            }
3291            LastValue {
3292                order_by,
3293                window_frame,
3294            } => {
3295                let order_by = order_by.iter().map(|col| self.child(col));
3296                f.write_str(name)?;
3297                f.write_str("[")?;
3298                write!(f, "order_by=[{}]", separated(", ", order_by))?;
3299                if *window_frame != WindowFrame::default() {
3300                    write!(f, " {}", window_frame)?;
3301                }
3302                f.write_str("]")
3303            }
3304            WindowAggregate {
3305                wrapped_aggregate,
3306                order_by,
3307                window_frame,
3308            } => {
3309                let order_by = order_by.iter().map(|col| self.child(col));
3310                let wrapped_aggregate = self.child(wrapped_aggregate.deref());
3311                f.write_str(name)?;
3312                f.write_str("[")?;
3313                write!(f, "{} ", wrapped_aggregate)?;
3314                write!(f, "order_by=[{}]", separated(", ", order_by))?;
3315                if *window_frame != WindowFrame::default() {
3316                    write!(f, " {}", window_frame)?;
3317                }
3318                f.write_str("]")
3319            }
3320            FusedValueWindowFunc { funcs, order_by } => {
3321                let order_by = order_by.iter().map(|col| self.child(col));
3322                let funcs = separated(", ", funcs.iter().map(|func| self.child(func)));
3323                f.write_str(name)?;
3324                f.write_str("[")?;
3325                write!(f, "{} ", funcs)?;
3326                write!(f, "order_by=[{}]", separated(", ", order_by))?;
3327                f.write_str("]")
3328            }
3329            _ => f.write_str(name),
3330        }
3331    }
3332}
3333
/// Describes one capture group of a regex, as recorded by
/// `AnalyzedRegex::new`.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct CaptureGroupDesc {
    /// Position of the group within the regex. Group 0 (the whole match) is
    /// never described, so indexes start at 1.
    pub index: u32,
    /// The group's name, if it is a named group.
    pub name: Option<String>,
    /// Whether the group may produce a null. `AnalyzedRegex::new` currently
    /// always sets this to `true`.
    pub nullable: bool,
}
3351
/// Options parsed from a regex flag string; see the `FromStr` impl for the
/// accepted flag characters. All options default to `false`.
#[derive(
    Clone,
    Copy,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect,
    Default
)]
pub struct AnalyzedRegexOpts {
    /// Set by the `i` flag; passed to the regex compiler by
    /// `AnalyzedRegex::new`.
    pub case_insensitive: bool,
    /// Set by the `g` flag; `regexp_matches` returns every match instead of
    /// only the first when this is set.
    pub global: bool,
}
3370
3371impl FromStr for AnalyzedRegexOpts {
3372    type Err = EvalError;
3373
3374    fn from_str(s: &str) -> Result<Self, Self::Err> {
3375        let mut opts = AnalyzedRegexOpts::default();
3376        for c in s.chars() {
3377            match c {
3378                'i' => opts.case_insensitive = true,
3379                'g' => opts.global = true,
3380                _ => return Err(EvalError::InvalidRegexFlag(c)),
3381            }
3382        }
3383        Ok(opts)
3384    }
3385}
3386
/// A compiled regex bundled with metadata about its capture groups.
///
/// The fields are, in order: the compiled regex, a description of every
/// capture group except group 0 (the whole match), and the options the regex
/// was compiled with.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AnalyzedRegex(ReprRegex, Vec<CaptureGroupDesc>, AnalyzedRegexOpts);
3400
3401impl AnalyzedRegex {
3402    pub fn new(s: &str, opts: AnalyzedRegexOpts) -> Result<Self, RegexCompilationError> {
3403        let r = ReprRegex::new(s, opts.case_insensitive)?;
3404        // TODO(benesch): remove potentially dangerous usage of `as`.
3405        #[allow(clippy::as_conversions)]
3406        let descs: Vec<_> = r
3407            .capture_names()
3408            .enumerate()
3409            // The first capture is the entire matched string.
3410            // This will often not be useful, so skip it.
3411            // If people want it they can just surround their
3412            // entire regex in an explicit capture group.
3413            .skip(1)
3414            .map(|(i, name)| CaptureGroupDesc {
3415                index: i as u32,
3416                name: name.map(String::from),
3417                // TODO -- we can do better.
3418                // https://github.com/MaterializeInc/database-issues/issues/612
3419                nullable: true,
3420            })
3421            .collect();
3422        Ok(Self(r, descs, opts))
3423    }
3424    pub fn capture_groups_len(&self) -> usize {
3425        self.1.len()
3426    }
3427    pub fn capture_groups_iter(&self) -> impl Iterator<Item = &CaptureGroupDesc> {
3428        self.1.iter()
3429    }
3430    pub fn inner(&self) -> &Regex {
3431        &(self.0).regex
3432    }
3433    pub fn opts(&self) -> &AnalyzedRegexOpts {
3434        &self.2
3435    }
3436}
3437
3438pub fn csv_extract(a: Datum<'_>, n_cols: usize) -> impl Iterator<Item = (Row, Diff)> + '_ {
3439    let bytes = a.unwrap_str().as_bytes();
3440    let mut row = Row::default();
3441    let csv_reader = csv::ReaderBuilder::new()
3442        .has_headers(false)
3443        .from_reader(bytes);
3444    csv_reader.into_records().filter_map(move |res| match res {
3445        Ok(sr) if sr.len() == n_cols => {
3446            row.packer().extend(sr.iter().map(Datum::String));
3447            Some((row.clone(), Diff::ONE))
3448        }
3449        _ => None,
3450    })
3451}
3452
3453pub fn repeat_row(a: Datum) -> Option<(Row, Diff)> {
3454    let n = a.unwrap_int64();
3455    if n != 0 {
3456        Some((Row::default(), n.into()))
3457    } else {
3458        None
3459    }
3460}
3461
3462pub fn repeat_row_non_negative<'a>(
3463    a: Datum,
3464) -> Result<Box<dyn Iterator<Item = (Row, Diff)> + 'a>, EvalError> {
3465    let n = a.unwrap_int64();
3466    if n < 0 {
3467        Err(EvalError::InvalidParameterValue(
3468            format!("repeat_row_non_negative got {}", n).into(),
3469        ))
3470    } else if n == 0 {
3471        Ok(Box::new(iter::empty()))
3472    } else {
3473        // iterator with 1 element; n goes into the diff
3474        Ok(Box::new(iter::once((Row::default(), n.into()))))
3475    }
3476}
3477
3478fn wrap<'a>(datums: &'a [Datum<'a>], width: usize) -> impl Iterator<Item = (Row, Diff)> + 'a {
3479    datums
3480        .chunks(width)
3481        .map(|chunk| (Row::pack(chunk), Diff::ONE))
3482}
3483
3484fn acl_explode<'a>(
3485    acl_items: Datum<'a>,
3486    temp_storage: &'a RowArena,
3487) -> Result<impl Iterator<Item = (Row, Diff)> + 'a, EvalError> {
3488    let acl_items = acl_items.unwrap_array();
3489    let mut res = Vec::new();
3490    for acl_item in acl_items.elements().iter() {
3491        if acl_item.is_null() {
3492            return Err(EvalError::AclArrayNullElement);
3493        }
3494        let acl_item = acl_item.unwrap_acl_item();
3495        for privilege in acl_item.acl_mode.explode() {
3496            let row = [
3497                Datum::UInt32(acl_item.grantor.0),
3498                Datum::UInt32(acl_item.grantee.0),
3499                Datum::String(temp_storage.push_string(privilege.to_string())),
3500                // GRANT OPTION is not implemented, so we hardcode false.
3501                Datum::False,
3502            ];
3503            res.push((Row::pack_slice(&row), Diff::ONE));
3504        }
3505    }
3506    Ok(res.into_iter())
3507}
3508
3509fn mz_acl_explode<'a>(
3510    mz_acl_items: Datum<'a>,
3511    temp_storage: &'a RowArena,
3512) -> Result<impl Iterator<Item = (Row, Diff)> + 'a, EvalError> {
3513    let mz_acl_items = mz_acl_items.unwrap_array();
3514    let mut res = Vec::new();
3515    for mz_acl_item in mz_acl_items.elements().iter() {
3516        if mz_acl_item.is_null() {
3517            return Err(EvalError::MzAclArrayNullElement);
3518        }
3519        let mz_acl_item = mz_acl_item.unwrap_mz_acl_item();
3520        for privilege in mz_acl_item.acl_mode.explode() {
3521            let row = [
3522                Datum::String(temp_storage.push_string(mz_acl_item.grantor.to_string())),
3523                Datum::String(temp_storage.push_string(mz_acl_item.grantee.to_string())),
3524                Datum::String(temp_storage.push_string(privilege.to_string())),
3525                // GRANT OPTION is not implemented, so we hardcode false.
3526                Datum::False,
3527            ];
3528            res.push((Row::pack_slice(&row), Diff::ONE));
3529        }
3530    }
3531    Ok(res.into_iter())
3532}
3533
/// A table function: evaluated on one input row of arguments, it produces
/// zero or more output rows, each with a multiplicity (see `TableFunc::eval`).
///
/// When adding a new `TableFunc` variant, please consider adding it to
/// `TableFunc::with_ordinality`!
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum TableFunc {
    /// Expands an array of ACL items into one row per individual privilege.
    AclExplode,
    /// Expands an array of `mz_aclitem`s into one row per individual
    /// privilege.
    MzAclExplode,
    JsonbEach,
    JsonbEachStringify,
    /// Emits one row per key of the input map; non-map inputs yield no rows.
    JsonbObjectKeys,
    /// Emits one row per element of the input list; non-list inputs yield no
    /// rows.
    JsonbArrayElements,
    /// Like `JsonbArrayElements`, but each element is stringified.
    JsonbArrayElementsStringify,
    /// Emits at most one row holding the capture groups of the regex matched
    /// against the input string.
    RegexpExtract(AnalyzedRegex),
    /// Parses the input string as header-less CSV, keeping only records with
    /// exactly the given number of fields.
    CsvExtract(usize),
    /// Inclusive int32 series over `(start, stop, step)`; errors on a zero
    /// step.
    GenerateSeriesInt32,
    /// Inclusive int64 series over `(start, stop, step)`; errors on a zero
    /// step.
    GenerateSeriesInt64,
    /// An int64 `generate_series` that the optimizer promises to leave as an
    /// enumeration: no transform may match on this variant to replace its
    /// evaluation with a cardinality shortcut (compare the collapse of an
    /// unused `GenerateSeriesInt64` into `RepeatRowNonNegative`). Its
    /// *argument* expressions are still simplified like any other scalar.
    ///
    /// Exposed as `mz_unsafe.generate_series_unoptimized` for tests that rely
    /// on the work of enumeration actually happening (e.g. stress tests whose
    /// load would otherwise be optimized away). As with everything in
    /// `mz_unsafe`, it is not a supported surface: bug reports must not
    /// depend on it.
    GenerateSeriesUnoptimized,
    /// Inclusive timestamp series over `(start, stop, step)` with an interval
    /// step; errors on a zero step.
    GenerateSeriesTimestamp,
    /// Inclusive timestamp-with-time-zone series over `(start, stop, step)`
    /// with an interval step; errors on a zero step.
    GenerateSeriesTimestampTz,
    /// Supplied with an input count,
    ///   1. Adds a column as if a typed subquery result,
    ///   2. Filters the row away if the count is zero or one,
    ///   3. Errors if the count exceeds one (or is negative).
    /// The intent is that this presents as if a subquery result with too many
    /// records contributing. The error column has the same type as the result
    /// should have, but we only produce it if the count exceeds one.
    ///
    /// This logic could nearly be achieved with map, filter, project logic,
    /// but has been challenging to do in a way that respects the vagaries of
    /// SQL and our semantics. If we reveal a constant value in the column we
    /// risk the optimizer pruning the branch; if we reveal that this will not
    /// produce rows we risk the optimizer pruning the branch; if we reveal that
    /// the only possible value is an error we risk the optimizer propagating that
    /// error without guards.
    ///
    /// Before replacing this by an `MirScalarExpr`, quadruple check that it
    /// would not result in misoptimizations due to expression evaluation order
    /// being utterly undefined, and predicate pushdown trimming any fragments
    /// that might produce columns that will not be needed.
    GuardSubquerySize {
        column_type: SqlScalarType,
    },
    /// Repeats the input row the given number of times. Can even repeat a negative number of times,
    /// which has some important consequences:
    /// - can lead to negative accumulations downstream;
    /// - can't be used in `WITH ORDINALITY` and other constructs that are implemented by
    ///   `TableFunc::WithOrdinality`, e.g., `ROWS FROM`;
    /// - output is non-monotonic.
    RepeatRow,
    /// Same as `RepeatRow`, but errors on a negative count, and thereby avoids the above
    /// peculiarities.
    RepeatRowNonNegative,
    /// Emits one row per element of the input array.
    UnnestArray {
        el_typ: SqlScalarType,
    },
    /// Emits one row per element of the input list.
    UnnestList {
        el_typ: SqlScalarType,
    },
    /// Emits one `(key, value)` row per entry of the input map.
    UnnestMap {
        value_type: SqlScalarType,
    },
    /// Given `n` input expressions, wraps them into `n / width` rows, each of
    /// `width` columns.
    ///
    /// This function is not intended to be called directly by end users, but
    /// is useful in the planning of e.g. VALUES clauses.
    Wrap {
        types: Vec<SqlColumnType>,
        width: usize,
    },
    /// Emits the subscripts of the requested (1-based) dimension of the input
    /// array.
    GenerateSubscriptsArray,
    /// Execute some arbitrary scalar function as a table function.
    TabletizedScalar {
        name: String,
        relation: SqlRelationType,
    },
    /// Emits one row per regex match, each holding an array of the match's
    /// capture groups; only the first match unless the `g` flag is given.
    RegexpMatches,
    /// Implements the WITH ORDINALITY clause.
    ///
    /// Don't construct `TableFunc::WithOrdinality` manually! Use the `with_ordinality` constructor
    /// function instead, which checks whether the given table function supports `WithOrdinality`.
    #[allow(private_interfaces)]
    WithOrdinality(WithOrdinality),
}
3639
/// Evaluates the inner table function, expands its results into unary (repeating each row as
/// many times as the diff indicates), and appends an integer corresponding to the ordinal
/// position (starting from 1). For example, it numbers the elements of a list when calling
/// `unnest_list`.
///
/// Private enum variant of `TableFunc`. Don't construct this directly, but use
/// `TableFunc::with_ordinality` instead.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
struct WithOrdinality {
    // The wrapped table function whose output rows get numbered. Boxed
    // because `TableFunc` itself contains `WithOrdinality`.
    inner: Box<TableFunc>,
}
3662
3663impl TableFunc {
3664    /// Adds `WITH ORDINALITY` to a table function if it's allowed on the given table function.
3665    pub fn with_ordinality(inner: TableFunc) -> Option<TableFunc> {
3666        match inner {
3667            TableFunc::AclExplode
3668            | TableFunc::MzAclExplode
3669            | TableFunc::JsonbEach
3670            | TableFunc::JsonbEachStringify
3671            | TableFunc::JsonbObjectKeys
3672            | TableFunc::JsonbArrayElements
3673            | TableFunc::JsonbArrayElementsStringify
3674            | TableFunc::RegexpExtract(_)
3675            | TableFunc::CsvExtract(_)
3676            | TableFunc::GenerateSeriesInt32
3677            | TableFunc::GenerateSeriesInt64
3678            | TableFunc::GenerateSeriesUnoptimized
3679            | TableFunc::GenerateSeriesTimestamp
3680            | TableFunc::GenerateSeriesTimestampTz
3681            | TableFunc::GuardSubquerySize { .. }
3682            | TableFunc::RepeatRowNonNegative
3683            | TableFunc::UnnestArray { .. }
3684            | TableFunc::UnnestList { .. }
3685            | TableFunc::UnnestMap { .. }
3686            | TableFunc::Wrap { .. }
3687            | TableFunc::GenerateSubscriptsArray
3688            | TableFunc::TabletizedScalar { .. }
3689            | TableFunc::RegexpMatches => Some(TableFunc::WithOrdinality(WithOrdinality {
3690                inner: Box::new(inner),
3691            })),
3692            // IMPORTANT: Before adding a new table function above, consider negative diffs:
3693            // `WithOrdinality::eval` will panic if the inner table function emits a negative diff.
3694            // (Note that negative diffs in the table function's _input_ don't matter. The table
3695            // function implementation doesn't see the input diffs, so the thing that matters here
3696            // is whether the table function itself can emit a negative diff.)
3697            TableFunc::RepeatRow // can produce negative diffs
3698            | TableFunc::WithOrdinality(_) => None, // no nesting of `WITH ORDINALITY` allowed
3699        }
3700    }
3701}
3702
3703impl TableFunc {
3704    /// Executes `self` on the given input row (`datums`).
3705    pub fn eval<'a>(
3706        &'a self,
3707        datums: &'a [Datum<'a>],
3708        temp_storage: &'a RowArena,
3709    ) -> Result<Box<dyn Iterator<Item = (Row, Diff)> + 'a>, EvalError> {
3710        if self.empty_on_null_input() && datums.iter().any(|d| d.is_null()) {
3711            return Ok(Box::new(vec![].into_iter()));
3712        }
3713        match self {
3714            TableFunc::AclExplode => Ok(Box::new(acl_explode(datums[0], temp_storage)?)),
3715            TableFunc::MzAclExplode => Ok(Box::new(mz_acl_explode(datums[0], temp_storage)?)),
3716            TableFunc::JsonbEach => Ok(Box::new(jsonb_each(datums[0]))),
3717            TableFunc::JsonbEachStringify => {
3718                Ok(Box::new(jsonb_each_stringify(datums[0], temp_storage)))
3719            }
3720            TableFunc::JsonbObjectKeys => Ok(Box::new(jsonb_object_keys(datums[0]))),
3721            TableFunc::JsonbArrayElements => Ok(Box::new(jsonb_array_elements(datums[0]))),
3722            TableFunc::JsonbArrayElementsStringify => Ok(Box::new(jsonb_array_elements_stringify(
3723                datums[0],
3724                temp_storage,
3725            ))),
3726            TableFunc::RegexpExtract(a) => Ok(Box::new(regexp_extract(datums[0], a).into_iter())),
3727            TableFunc::CsvExtract(n_cols) => Ok(Box::new(csv_extract(datums[0], *n_cols))),
3728            TableFunc::GenerateSeriesInt32 => {
3729                let res = generate_series(
3730                    datums[0].unwrap_int32(),
3731                    datums[1].unwrap_int32(),
3732                    datums[2].unwrap_int32(),
3733                )?;
3734                Ok(Box::new(res))
3735            }
3736            TableFunc::GenerateSeriesInt64 | TableFunc::GenerateSeriesUnoptimized => {
3737                let res = generate_series(
3738                    datums[0].unwrap_int64(),
3739                    datums[1].unwrap_int64(),
3740                    datums[2].unwrap_int64(),
3741                )?;
3742                Ok(Box::new(res))
3743            }
3744            TableFunc::GenerateSeriesTimestamp => {
3745                fn pass_through<'a>(d: CheckedTimestamp<NaiveDateTime>) -> Datum<'a> {
3746                    Datum::from(d)
3747                }
3748                let res = generate_series_ts(
3749                    datums[0].unwrap_timestamp(),
3750                    datums[1].unwrap_timestamp(),
3751                    datums[2].unwrap_interval(),
3752                    pass_through,
3753                )?;
3754                Ok(Box::new(res))
3755            }
3756            TableFunc::GenerateSeriesTimestampTz => {
3757                fn gen_ts_tz<'a>(d: CheckedTimestamp<DateTime<Utc>>) -> Datum<'a> {
3758                    Datum::from(d)
3759                }
3760                let res = generate_series_ts(
3761                    datums[0].unwrap_timestamptz(),
3762                    datums[1].unwrap_timestamptz(),
3763                    datums[2].unwrap_interval(),
3764                    gen_ts_tz,
3765                )?;
3766                Ok(Box::new(res))
3767            }
3768            TableFunc::GenerateSubscriptsArray => {
3769                generate_subscripts_array(datums[0], datums[1].unwrap_int32())
3770            }
3771            TableFunc::GuardSubquerySize { column_type: _ } => {
3772                // A subquery used as an expression may return at most one row;
3773                // for 0 or 1 we emit no rows and let the subquery's own output
3774                // flow through. Zero is benign, not "can't happen": constant
3775                // folding can prove the body empty and fold its count to a
3776                // literal `0`, which decorrelates to NULL via the outer lookup.
3777                let count = datums[0].unwrap_int64();
3778                if count > 1 {
3779                    Err(EvalError::MultipleRowsFromSubquery)
3780                } else if count < 0 {
3781                    // Would require negative multiplicities to reach the guard.
3782                    Err(EvalError::NegativeRowsFromSubquery)
3783                } else {
3784                    Ok(Box::new([].into_iter()))
3785                }
3786            }
3787            TableFunc::RepeatRow => Ok(Box::new(repeat_row(datums[0]).into_iter())),
3788            TableFunc::RepeatRowNonNegative => repeat_row_non_negative(datums[0]),
3789            TableFunc::UnnestArray { .. } => Ok(Box::new(unnest_array(datums[0]))),
3790            TableFunc::UnnestList { .. } => Ok(Box::new(unnest_list(datums[0]))),
3791            TableFunc::UnnestMap { .. } => Ok(Box::new(unnest_map(datums[0]))),
3792            TableFunc::Wrap { width, .. } => Ok(Box::new(wrap(datums, *width))),
3793            TableFunc::TabletizedScalar { .. } => {
3794                let r = Row::pack_slice(datums);
3795                Ok(Box::new(std::iter::once((r, Diff::ONE))))
3796            }
3797            TableFunc::RegexpMatches => Ok(Box::new(regexp_matches(datums)?)),
3798            TableFunc::WithOrdinality(func_with_ordinality) => {
3799                func_with_ordinality.eval(datums, temp_storage)
3800            }
3801        }
3802    }
3803
3804    pub fn output_sql_type(&self) -> SqlRelationType {
3805        let (column_types, keys) = match self {
3806            TableFunc::AclExplode => {
3807                let column_types = vec![
3808                    SqlScalarType::Oid.nullable(false),
3809                    SqlScalarType::Oid.nullable(false),
3810                    SqlScalarType::String.nullable(false),
3811                    SqlScalarType::Bool.nullable(false),
3812                ];
3813                let keys = vec![];
3814                (column_types, keys)
3815            }
3816            TableFunc::MzAclExplode => {
3817                let column_types = vec![
3818                    SqlScalarType::String.nullable(false),
3819                    SqlScalarType::String.nullable(false),
3820                    SqlScalarType::String.nullable(false),
3821                    SqlScalarType::Bool.nullable(false),
3822                ];
3823                let keys = vec![];
3824                (column_types, keys)
3825            }
3826            TableFunc::JsonbEach => {
3827                let column_types = vec![
3828                    SqlScalarType::String.nullable(false),
3829                    SqlScalarType::Jsonb.nullable(false),
3830                ];
3831                let keys = vec![];
3832                (column_types, keys)
3833            }
3834            TableFunc::JsonbEachStringify => {
3835                let column_types = vec![
3836                    SqlScalarType::String.nullable(false),
3837                    SqlScalarType::String.nullable(true),
3838                ];
3839                let keys = vec![];
3840                (column_types, keys)
3841            }
3842            TableFunc::JsonbObjectKeys => {
3843                let column_types = vec![SqlScalarType::String.nullable(false)];
3844                let keys = vec![];
3845                (column_types, keys)
3846            }
3847            TableFunc::JsonbArrayElements => {
3848                let column_types = vec![SqlScalarType::Jsonb.nullable(false)];
3849                let keys = vec![];
3850                (column_types, keys)
3851            }
3852            TableFunc::JsonbArrayElementsStringify => {
3853                let column_types = vec![SqlScalarType::String.nullable(true)];
3854                let keys = vec![];
3855                (column_types, keys)
3856            }
3857            TableFunc::RegexpExtract(a) => {
3858                let column_types = a
3859                    .capture_groups_iter()
3860                    .map(|cg| SqlScalarType::String.nullable(cg.nullable))
3861                    .collect();
3862                let keys = vec![];
3863                (column_types, keys)
3864            }
3865            TableFunc::CsvExtract(n_cols) => {
3866                let column_types = iter::repeat(SqlScalarType::String.nullable(false))
3867                    .take(*n_cols)
3868                    .collect();
3869                let keys = vec![];
3870                (column_types, keys)
3871            }
3872            TableFunc::GenerateSeriesInt32 => {
3873                let column_types = vec![SqlScalarType::Int32.nullable(false)];
3874                let keys = vec![vec![0]];
3875                (column_types, keys)
3876            }
3877            TableFunc::GenerateSeriesInt64 | TableFunc::GenerateSeriesUnoptimized => {
3878                let column_types = vec![SqlScalarType::Int64.nullable(false)];
3879                let keys = vec![vec![0]];
3880                (column_types, keys)
3881            }
3882            TableFunc::GenerateSeriesTimestamp => {
3883                let column_types =
3884                    vec![SqlScalarType::Timestamp { precision: None }.nullable(false)];
3885                let keys = vec![vec![0]];
3886                (column_types, keys)
3887            }
3888            TableFunc::GenerateSeriesTimestampTz => {
3889                let column_types =
3890                    vec![SqlScalarType::TimestampTz { precision: None }.nullable(false)];
3891                let keys = vec![vec![0]];
3892                (column_types, keys)
3893            }
3894            TableFunc::GenerateSubscriptsArray => {
3895                let column_types = vec![SqlScalarType::Int32.nullable(false)];
3896                let keys = vec![vec![0]];
3897                (column_types, keys)
3898            }
3899            TableFunc::GuardSubquerySize { column_type } => {
3900                let column_types = vec![column_type.clone().nullable(false)];
3901                let keys = vec![];
3902                (column_types, keys)
3903            }
3904            TableFunc::RepeatRow | TableFunc::RepeatRowNonNegative => {
3905                let column_types = vec![];
3906                let keys = vec![];
3907                (column_types, keys)
3908            }
3909            TableFunc::UnnestArray { el_typ } => {
3910                let column_types = vec![el_typ.clone().nullable(true)];
3911                let keys = vec![];
3912                (column_types, keys)
3913            }
3914            TableFunc::UnnestList { el_typ } => {
3915                let column_types = vec![el_typ.clone().nullable(true)];
3916                let keys = vec![];
3917                (column_types, keys)
3918            }
3919            TableFunc::UnnestMap { value_type } => {
3920                let column_types = vec![
3921                    SqlScalarType::String.nullable(false),
3922                    value_type.clone().nullable(true),
3923                ];
3924                let keys = vec![vec![0]];
3925                (column_types, keys)
3926            }
3927            TableFunc::Wrap { types, .. } => {
3928                let column_types = types.clone();
3929                let keys = vec![];
3930                (column_types, keys)
3931            }
3932            TableFunc::TabletizedScalar { relation, .. } => {
3933                return relation.clone();
3934            }
3935            TableFunc::RegexpMatches => {
3936                let column_types =
3937                    vec![SqlScalarType::Array(Box::new(SqlScalarType::String)).nullable(false)];
3938                let keys = vec![];
3939
3940                (column_types, keys)
3941            }
3942            TableFunc::WithOrdinality(WithOrdinality { inner }) => {
3943                let mut typ = inner.output_sql_type();
3944                // Add the ordinality column.
3945                typ.column_types.push(SqlScalarType::Int64.nullable(false));
3946                // The ordinality column is always a key.
3947                typ.keys.push(vec![typ.column_types.len() - 1]);
3948                (typ.column_types, typ.keys)
3949            }
3950        };
3951
3952        soft_assert_eq_no_log!(column_types.len(), self.output_arity());
3953
3954        if !keys.is_empty() {
3955            SqlRelationType::new(column_types).with_keys(keys)
3956        } else {
3957            SqlRelationType::new(column_types)
3958        }
3959    }
3960
3961    /// Computes the representation type of this table function.
3962    ///
3963    /// This is a wrapper around [`Self::output_sql_type`] that converts the result to a representation type.
3964    pub fn output_type(&self) -> ReprRelationType {
3965        ReprRelationType::from(&self.output_sql_type())
3966    }
3967
3968    pub fn output_arity(&self) -> usize {
3969        match self {
3970            TableFunc::AclExplode => 4,
3971            TableFunc::MzAclExplode => 4,
3972            TableFunc::JsonbEach => 2,
3973            TableFunc::JsonbEachStringify => 2,
3974            TableFunc::JsonbObjectKeys => 1,
3975            TableFunc::JsonbArrayElements => 1,
3976            TableFunc::JsonbArrayElementsStringify => 1,
3977            TableFunc::RegexpExtract(a) => a.capture_groups_len(),
3978            TableFunc::CsvExtract(n_cols) => *n_cols,
3979            TableFunc::GenerateSeriesInt32 => 1,
3980            TableFunc::GenerateSeriesInt64 => 1,
3981            TableFunc::GenerateSeriesUnoptimized => 1,
3982            TableFunc::GenerateSeriesTimestamp => 1,
3983            TableFunc::GenerateSeriesTimestampTz => 1,
3984            TableFunc::GenerateSubscriptsArray => 1,
3985            TableFunc::GuardSubquerySize { .. } => 1,
3986            TableFunc::RepeatRow => 0,
3987            TableFunc::RepeatRowNonNegative => 0,
3988            TableFunc::UnnestArray { .. } => 1,
3989            TableFunc::UnnestList { .. } => 1,
3990            TableFunc::UnnestMap { .. } => 2,
3991            TableFunc::Wrap { width, .. } => *width,
3992            TableFunc::TabletizedScalar { relation, .. } => relation.column_types.len(),
3993            TableFunc::RegexpMatches => 1,
3994            TableFunc::WithOrdinality(WithOrdinality { inner }) => inner.output_arity() + 1,
3995        }
3996    }
3997
3998    pub fn empty_on_null_input(&self) -> bool {
3999        match self {
4000            TableFunc::AclExplode
4001            | TableFunc::MzAclExplode
4002            | TableFunc::JsonbEach
4003            | TableFunc::JsonbEachStringify
4004            | TableFunc::JsonbObjectKeys
4005            | TableFunc::JsonbArrayElements
4006            | TableFunc::JsonbArrayElementsStringify
4007            | TableFunc::GenerateSeriesInt32
4008            | TableFunc::GenerateSeriesInt64
4009            | TableFunc::GenerateSeriesUnoptimized
4010            | TableFunc::GenerateSeriesTimestamp
4011            | TableFunc::GenerateSeriesTimestampTz
4012            | TableFunc::GenerateSubscriptsArray
4013            | TableFunc::RegexpExtract(_)
4014            | TableFunc::CsvExtract(_)
4015            | TableFunc::RepeatRow
4016            | TableFunc::RepeatRowNonNegative
4017            | TableFunc::UnnestArray { .. }
4018            | TableFunc::UnnestList { .. }
4019            | TableFunc::UnnestMap { .. }
4020            | TableFunc::RegexpMatches => true,
4021            TableFunc::GuardSubquerySize { .. } => false,
4022            TableFunc::Wrap { .. } => false,
4023            TableFunc::TabletizedScalar { .. } => false,
4024            TableFunc::WithOrdinality(WithOrdinality { inner }) => inner.empty_on_null_input(),
4025        }
4026    }
4027
4028    /// True iff the table function preserves the append-only property of its input.
4029    pub fn preserves_monotonicity(&self) -> bool {
4030        // Most variants preserve monotonicity, but all variants are enumerated to
4031        // ensure that added variants at least check that this is the case.
4032        match self {
4033            TableFunc::AclExplode => false,
4034            TableFunc::MzAclExplode => false,
4035            TableFunc::JsonbEach => true,
4036            TableFunc::JsonbEachStringify => true,
4037            TableFunc::JsonbObjectKeys => true,
4038            TableFunc::JsonbArrayElements => true,
4039            TableFunc::JsonbArrayElementsStringify => true,
4040            TableFunc::RegexpExtract(_) => true,
4041            TableFunc::CsvExtract(_) => true,
4042            TableFunc::GenerateSeriesInt32 => true,
4043            TableFunc::GenerateSeriesInt64 => true,
4044            TableFunc::GenerateSeriesUnoptimized => true,
4045            TableFunc::GenerateSeriesTimestamp => true,
4046            TableFunc::GenerateSeriesTimestampTz => true,
4047            TableFunc::GenerateSubscriptsArray => true,
4048            TableFunc::RepeatRow => false,
4049            TableFunc::RepeatRowNonNegative => true,
4050            TableFunc::UnnestArray { .. } => true,
4051            TableFunc::UnnestList { .. } => true,
4052            TableFunc::UnnestMap { .. } => true,
4053            TableFunc::Wrap { .. } => true,
4054            TableFunc::TabletizedScalar { .. } => true,
4055            TableFunc::RegexpMatches => true,
4056            TableFunc::GuardSubquerySize { .. } => false,
4057            TableFunc::WithOrdinality(WithOrdinality { inner }) => inner.preserves_monotonicity(),
4058        }
4059    }
4060}
4061
4062impl fmt::Display for TableFunc {
4063    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
4064        match self {
4065            TableFunc::AclExplode => f.write_str("aclexplode"),
4066            TableFunc::MzAclExplode => f.write_str("mz_aclexplode"),
4067            TableFunc::JsonbEach => f.write_str("jsonb_each"),
4068            TableFunc::JsonbEachStringify => f.write_str("jsonb_each_text"),
4069            TableFunc::JsonbObjectKeys => f.write_str("jsonb_object_keys"),
4070            TableFunc::JsonbArrayElements => f.write_str("jsonb_array_elements"),
4071            TableFunc::JsonbArrayElementsStringify => f.write_str("jsonb_array_elements_text"),
4072            TableFunc::RegexpExtract(a) => write!(f, "regexp_extract({:?}, _)", a.0),
4073            TableFunc::CsvExtract(n_cols) => write!(f, "csv_extract({}, _)", n_cols),
4074            TableFunc::GenerateSeriesInt32 => f.write_str("generate_series"),
4075            TableFunc::GenerateSeriesInt64 => f.write_str("generate_series"),
4076            TableFunc::GenerateSeriesUnoptimized => f.write_str("generate_series_unoptimized"),
4077            TableFunc::GenerateSeriesTimestamp => f.write_str("generate_series"),
4078            TableFunc::GenerateSeriesTimestampTz => f.write_str("generate_series"),
4079            TableFunc::GenerateSubscriptsArray => f.write_str("generate_subscripts"),
4080            TableFunc::GuardSubquerySize { .. } => f.write_str("guard_subquery_size"),
4081            TableFunc::RepeatRow => f.write_str(REPEAT_ROW_NAME),
4082            TableFunc::RepeatRowNonNegative => f.write_str("repeat_row_non_negative"),
4083            TableFunc::UnnestArray { .. } => f.write_str("unnest_array"),
4084            TableFunc::UnnestList { .. } => f.write_str("unnest_list"),
4085            TableFunc::UnnestMap { .. } => f.write_str("unnest_map"),
4086            TableFunc::Wrap { width, .. } => write!(f, "wrap{}", width),
4087            TableFunc::TabletizedScalar { name, .. } => f.write_str(name),
4088            TableFunc::RegexpMatches => write!(f, "regexp_matches(_, _, _)"),
4089            TableFunc::WithOrdinality(WithOrdinality { inner }) => {
4090                write!(f, "{}[with_ordinality]", inner)
4091            }
4092        }
4093    }
4094}
4095
4096impl WithOrdinality {
4097    /// Executes the `self.inner` table function on the given input row (`datums`), and zips
4098    /// 1, 2, 3, ... to the result as a new column. We need to expand rows with non-1 diffs into the
4099    /// corresponding number of rows with unit diffs, because the ordinality column will have
4100    /// different values for each copy.
4101    ///
4102    /// # Panics
4103    ///
4104    /// Panics if the `inner` table function emits a negative diff.
4105    fn eval<'a>(
4106        &'a self,
4107        datums: &'a [Datum<'a>],
4108        temp_storage: &'a RowArena,
4109    ) -> Result<Box<dyn Iterator<Item = (Row, Diff)> + 'a>, EvalError> {
4110        let mut next_ordinal: i64 = 1;
4111        let it = self
4112            .inner
4113            .eval(datums, temp_storage)?
4114            .flat_map(move |(mut row, diff)| {
4115                let diff = diff.into_inner();
4116                // WITH ORDINALITY is not well-defined for negative diffs. This is ok, and
4117                // `TableFunc::with_ordinality` refuses to wrap such table functions in
4118                // `WithOrdinality` that can emit negative diffs, e.g., `repeat_row`.
4119                //
4120                // (Note that we don't need to worry about negative diffs in FlatMap's input,
4121                // because the diff of the input of the FlatMap is factored in after we return from
4122                // here.)
4123                assert!(diff >= 0);
4124                // The ordinals that will be associated with this row.
4125                let mut ordinals = next_ordinal..(next_ordinal + diff);
4126                next_ordinal += diff;
4127                // The maximum byte capacity we need for the original row and its ordinal.
4128                let cap = row.data_len() + datum_size(&Datum::Int64(next_ordinal));
4129                iter::from_fn(move || {
4130                    let ordinal = ordinals.next()?;
4131                    let mut row = if ordinals.is_empty() {
4132                        // This is the last row, so no need to clone. (Most table functions emit
4133                        // only 1 diffs, so this completely avoids cloning in most cases.)
4134                        std::mem::take(&mut row)
4135                    } else {
4136                        let mut new_row = Row::with_capacity(cap);
4137                        new_row.clone_from(&row);
4138                        new_row
4139                    };
4140                    RowPacker::for_existing_row(&mut row).push(Datum::Int64(ordinal));
4141                    Some((row, Diff::ONE))
4142                })
4143            });
4144        Ok(Box::new(it))
4145    }
4146}
4147
/// Name written by the `Display` impl of `TableFunc` for the `TableFunc::RepeatRow` variant.
4148pub const REPEAT_ROW_NAME: &str = "repeat_row";
4149
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, RowArena, SqlScalarType};

    use super::TableFunc;
    use crate::EvalError;

    /// Counts of 0 and 1 are accepted and produce no guard rows; a count above 1 fails with
    /// `MultipleRowsFromSubquery` and a negative count with `NegativeRowsFromSubquery`.
    ///
    /// Zero is a legitimate input rather than a "can't happen" case: constant folding can
    /// reduce an empty subquery's count to a literal `0`, and evaluating the guard on it must
    /// not panic (regression for #37049).
    #[mz_ore::test]
    fn guard_subquery_size_accepts_zero_and_one() {
        let temp_storage = RowArena::new();
        let func = TableFunc::GuardSubquerySize {
            column_type: SqlScalarType::Int64,
        };

        // Accepted counts: evaluation succeeds and the guard emits nothing.
        for count in [0_i64, 1] {
            let input = [Datum::Int64(count)];
            match func.eval(&input, &temp_storage) {
                Ok(guard_rows) => {
                    let rows = guard_rows.count();
                    assert_eq!(rows, 0, "count {count} should emit no guard rows");
                }
                Err(e) => panic!("count {count} should be accepted, got {e:?}"),
            }
        }

        // Rejected counts, each paired with the error it must produce.
        let rejected = [
            (2_i64, EvalError::MultipleRowsFromSubquery),
            (-1_i64, EvalError::NegativeRowsFromSubquery),
        ];
        for (count, expected) in rejected {
            assert_eq!(
                func.eval(&[Datum::Int64(count)], &temp_storage).err(),
                Some(expected),
            );
        }
    }
}