mz_expr/scalar/func/impls/
list.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use mz_lowertest::MzReflect;
13use mz_repr::{AsColumnType, Datum, DatumList, Row, RowArena, SqlColumnType, SqlScalarType};
14use serde::{Deserialize, Serialize};
15
16use crate::func::binary::EagerBinaryFunc;
17use crate::scalar::func::{LazyUnaryFunc, stringify_datum};
18use crate::{EvalError, MirScalarExpr};
19
20#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
21pub struct CastListToString {
22    pub ty: SqlScalarType,
23}
24
25impl LazyUnaryFunc for CastListToString {
26    fn eval<'a>(
27        &'a self,
28        datums: &[Datum<'a>],
29        temp_storage: &'a RowArena,
30        a: &'a MirScalarExpr,
31    ) -> Result<Datum<'a>, EvalError> {
32        let a = a.eval(datums, temp_storage)?;
33        if a.is_null() {
34            return Ok(Datum::Null);
35        }
36        let mut buf = String::new();
37        stringify_datum(&mut buf, a, &self.ty)?;
38        Ok(Datum::String(temp_storage.push_string(buf)))
39    }
40
41    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
42        SqlScalarType::String.nullable(input_type.nullable)
43    }
44
45    fn propagates_nulls(&self) -> bool {
46        true
47    }
48
49    fn introduces_nulls(&self) -> bool {
50        false
51    }
52
53    fn preserves_uniqueness(&self) -> bool {
54        true
55    }
56
57    fn inverse(&self) -> Option<crate::UnaryFunc> {
58        // TODO? if typeconv was in expr, we could determine this
59        None
60    }
61
62    fn is_monotone(&self) -> bool {
63        false
64    }
65}
66
67impl fmt::Display for CastListToString {
68    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69        f.write_str("listtostr")
70    }
71}
72
73#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
74pub struct CastListToJsonb {
75    pub cast_element: Box<MirScalarExpr>,
76}
77
78impl LazyUnaryFunc for CastListToJsonb {
79    fn eval<'a>(
80        &'a self,
81        datums: &[Datum<'a>],
82        temp_storage: &'a RowArena,
83        a: &'a MirScalarExpr,
84    ) -> Result<Datum<'a>, EvalError> {
85        let a = a.eval(datums, temp_storage)?;
86        if a.is_null() {
87            return Ok(Datum::Null);
88        }
89        let mut row = Row::default();
90        row.packer().push_list_with(|packer| {
91            for elem in a.unwrap_list().iter() {
92                let elem = match self.cast_element.eval(&[elem], temp_storage)? {
93                    Datum::Null => Datum::JsonNull,
94                    d => d,
95                };
96                packer.push(elem);
97            }
98            Ok::<_, EvalError>(())
99        })?;
100        Ok(temp_storage.push_unary_row(row))
101    }
102
103    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
104        SqlScalarType::Jsonb.nullable(input_type.nullable)
105    }
106
107    fn propagates_nulls(&self) -> bool {
108        true
109    }
110
111    fn introduces_nulls(&self) -> bool {
112        false
113    }
114
115    fn preserves_uniqueness(&self) -> bool {
116        true
117    }
118
119    fn inverse(&self) -> Option<crate::UnaryFunc> {
120        // TODO? If we moved typeconv into `expr` we could determine the right
121        // inverse of this.
122        None
123    }
124
125    fn is_monotone(&self) -> bool {
126        false
127    }
128}
129
130impl fmt::Display for CastListToJsonb {
131    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132        f.write_str("listtojsonb")
133    }
134}
135
136/// Casts between two list types by casting each element of `a` ("list1") using
137/// `cast_expr` and collecting the results into a new list ("list2").
138#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
139pub struct CastList1ToList2 {
140    /// List2's type
141    pub return_ty: SqlScalarType,
142    /// The expression to cast List1's elements to List2's elements' type
143    pub cast_expr: Box<MirScalarExpr>,
144}
145
146impl LazyUnaryFunc for CastList1ToList2 {
147    fn eval<'a>(
148        &'a self,
149        datums: &[Datum<'a>],
150        temp_storage: &'a RowArena,
151        a: &'a MirScalarExpr,
152    ) -> Result<Datum<'a>, EvalError> {
153        let a = a.eval(datums, temp_storage)?;
154        if a.is_null() {
155            return Ok(Datum::Null);
156        }
157        let mut cast_datums = Vec::new();
158        for el in a.unwrap_list().iter() {
159            // `cast_expr` is evaluated as an expression that casts the
160            // first column in `datums` (i.e. `datums[0]`) from the list elements'
161            // current type to a target type.
162            cast_datums.push(self.cast_expr.eval(&[el], temp_storage)?);
163        }
164
165        Ok(temp_storage.make_datum(|packer| packer.push_list(cast_datums)))
166    }
167
168    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
169        self.return_ty
170            .without_modifiers()
171            .nullable(input_type.nullable)
172    }
173
174    fn propagates_nulls(&self) -> bool {
175        true
176    }
177
178    fn introduces_nulls(&self) -> bool {
179        false
180    }
181
182    fn preserves_uniqueness(&self) -> bool {
183        false
184    }
185
186    fn inverse(&self) -> Option<crate::UnaryFunc> {
187        // TODO: this could be figured out--might be easier after enum dispatch?
188        None
189    }
190
191    fn is_monotone(&self) -> bool {
192        false
193    }
194}
195
196impl fmt::Display for CastList1ToList2 {
197    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
198        f.write_str("list1tolist2")
199    }
200}
201
202#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
203pub struct ListLength;
204
205impl LazyUnaryFunc for ListLength {
206    fn eval<'a>(
207        &'a self,
208        datums: &[Datum<'a>],
209        temp_storage: &'a RowArena,
210        a: &'a MirScalarExpr,
211    ) -> Result<Datum<'a>, EvalError> {
212        let a = a.eval(datums, temp_storage)?;
213        if a.is_null() {
214            return Ok(Datum::Null);
215        }
216        let count = a.unwrap_list().iter().count();
217        match count.try_into() {
218            Ok(c) => Ok(Datum::Int32(c)),
219            Err(_) => Err(EvalError::Int32OutOfRange(count.to_string().into())),
220        }
221    }
222
223    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
224        SqlScalarType::Int32.nullable(input_type.nullable)
225    }
226
227    fn propagates_nulls(&self) -> bool {
228        true
229    }
230
231    fn introduces_nulls(&self) -> bool {
232        false
233    }
234
235    fn preserves_uniqueness(&self) -> bool {
236        false
237    }
238
239    fn inverse(&self) -> Option<crate::UnaryFunc> {
240        None
241    }
242
243    fn is_monotone(&self) -> bool {
244        false
245    }
246}
247
248impl fmt::Display for ListLength {
249    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
250        f.write_str("list_length")
251    }
252}
253
254/// The `list_length_max` implementation.
255///
256/// We're not deriving `sqlfunc` here because we need to pass in the `max_layer` parameter.
257#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
258pub struct ListLengthMax {
259    /// Maximal allowed layer to query.
260    pub max_layer: usize,
261}
262impl<'a> EagerBinaryFunc<'a> for ListLengthMax {
263    type Input1 = DatumList<'a>;
264    type Input2 = i64;
265    type Output = Result<Option<i32>, EvalError>;
266    // TODO(benesch): remove potentially dangerous usage of `as`.
267    #[allow(clippy::as_conversions)]
268    fn call(&self, a: Self::Input1, b: Self::Input2, _: &'a RowArena) -> Self::Output {
269        fn max_len_on_layer(i: DatumList<'_>, on_layer: i64) -> Option<usize> {
270            let mut i = i.iter();
271            if on_layer > 1 {
272                let mut max_len = None;
273                while let Some(Datum::List(i)) = i.next() {
274                    max_len = std::cmp::max(max_len_on_layer(i, on_layer - 1), max_len);
275                }
276                max_len
277            } else {
278                Some(i.count())
279            }
280        }
281        if b as usize > self.max_layer || b < 1 {
282            Err(EvalError::InvalidLayer {
283                max_layer: self.max_layer,
284                val: b,
285            })
286        } else {
287            match max_len_on_layer(a, b) {
288                Some(l) => match l.try_into() {
289                    Ok(c) => Ok(Some(c)),
290                    Err(_) => Err(EvalError::Int32OutOfRange(l.to_string().into())),
291                },
292                None => Ok(None),
293            }
294        }
295    }
296    fn output_type(
297        &self,
298        input_type_a: SqlColumnType,
299        input_type_b: SqlColumnType,
300    ) -> SqlColumnType {
301        let output = Self::Output::as_column_type();
302        let propagates_nulls = self.propagates_nulls();
303        let nullable = output.nullable;
304        output.nullable(
305            nullable || (propagates_nulls && (input_type_a.nullable || input_type_b.nullable)),
306        )
307    }
308}
309impl fmt::Display for ListLengthMax {
310    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
311        f.write_str("list_length_max")
312    }
313}