Skip to main content

mz_expr/scalar/func/impls/
list.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use mz_expr_derive::sqlfunc;
13use mz_lowertest::MzReflect;
14use mz_repr::{AsColumnType, Datum, DatumList, Row, RowArena, SqlColumnType, SqlScalarType};
15use serde::{Deserialize, Serialize};
16
17use crate::func::binary::EagerBinaryFunc;
18use crate::scalar::func::{LazyUnaryFunc, stringify_datum};
19use crate::{EvalError, MirScalarExpr};
20
21#[derive(
22    Ord,
23    PartialOrd,
24    Clone,
25    Debug,
26    Eq,
27    PartialEq,
28    Serialize,
29    Deserialize,
30    Hash,
31    MzReflect
32)]
33pub struct CastListToString {
34    pub ty: SqlScalarType,
35}
36
37impl LazyUnaryFunc for CastListToString {
38    fn eval<'a>(
39        &'a self,
40        datums: &[Datum<'a>],
41        temp_storage: &'a RowArena,
42        a: &'a MirScalarExpr,
43    ) -> Result<Datum<'a>, EvalError> {
44        let a = a.eval(datums, temp_storage)?;
45        if a.is_null() {
46            return Ok(Datum::Null);
47        }
48        let mut buf = String::new();
49        stringify_datum(&mut buf, a, &self.ty)?;
50        Ok(Datum::String(temp_storage.push_string(buf)))
51    }
52
53    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
54        SqlScalarType::String.nullable(input_type.nullable)
55    }
56
57    fn propagates_nulls(&self) -> bool {
58        true
59    }
60
61    fn introduces_nulls(&self) -> bool {
62        false
63    }
64
65    fn preserves_uniqueness(&self) -> bool {
66        true
67    }
68
69    fn inverse(&self) -> Option<crate::UnaryFunc> {
70        // TODO? if typeconv was in expr, we could determine this
71        None
72    }
73
74    fn is_monotone(&self) -> bool {
75        false
76    }
77}
78
79impl fmt::Display for CastListToString {
80    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81        f.write_str("listtostr")
82    }
83}
84
85#[derive(
86    Ord,
87    PartialOrd,
88    Clone,
89    Debug,
90    Eq,
91    PartialEq,
92    Serialize,
93    Deserialize,
94    Hash,
95    MzReflect
96)]
97pub struct CastListToJsonb {
98    pub cast_element: Box<MirScalarExpr>,
99}
100
101impl LazyUnaryFunc for CastListToJsonb {
102    fn eval<'a>(
103        &'a self,
104        datums: &[Datum<'a>],
105        temp_storage: &'a RowArena,
106        a: &'a MirScalarExpr,
107    ) -> Result<Datum<'a>, EvalError> {
108        let a = a.eval(datums, temp_storage)?;
109        if a.is_null() {
110            return Ok(Datum::Null);
111        }
112        let mut row = Row::default();
113        row.packer().push_list_with(|packer| {
114            for elem in a.unwrap_list().iter() {
115                let elem = match self.cast_element.eval(&[elem], temp_storage)? {
116                    Datum::Null => Datum::JsonNull,
117                    d => d,
118                };
119                packer.push(elem);
120            }
121            Ok::<_, EvalError>(())
122        })?;
123        Ok(temp_storage.push_unary_row(row))
124    }
125
126    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
127        SqlScalarType::Jsonb.nullable(input_type.nullable)
128    }
129
130    fn propagates_nulls(&self) -> bool {
131        true
132    }
133
134    fn introduces_nulls(&self) -> bool {
135        false
136    }
137
138    fn preserves_uniqueness(&self) -> bool {
139        true
140    }
141
142    fn inverse(&self) -> Option<crate::UnaryFunc> {
143        // TODO? If we moved typeconv into `expr` we could determine the right
144        // inverse of this.
145        None
146    }
147
148    fn is_monotone(&self) -> bool {
149        false
150    }
151}
152
153impl fmt::Display for CastListToJsonb {
154    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
155        f.write_str("listtojsonb")
156    }
157}
158
159/// Casts between two list types by casting each element of `a` ("list1") using
160/// `cast_expr` and collecting the results into a new list ("list2").
161#[derive(
162    Ord,
163    PartialOrd,
164    Clone,
165    Debug,
166    Eq,
167    PartialEq,
168    Serialize,
169    Deserialize,
170    Hash,
171    MzReflect
172)]
173pub struct CastList1ToList2 {
174    /// List2's type
175    pub return_ty: SqlScalarType,
176    /// The expression to cast List1's elements to List2's elements' type
177    pub cast_expr: Box<MirScalarExpr>,
178}
179
180impl LazyUnaryFunc for CastList1ToList2 {
181    fn eval<'a>(
182        &'a self,
183        datums: &[Datum<'a>],
184        temp_storage: &'a RowArena,
185        a: &'a MirScalarExpr,
186    ) -> Result<Datum<'a>, EvalError> {
187        let a = a.eval(datums, temp_storage)?;
188        if a.is_null() {
189            return Ok(Datum::Null);
190        }
191        let mut cast_datums = Vec::new();
192        for el in a.unwrap_list().iter() {
193            // `cast_expr` is evaluated as an expression that casts the
194            // first column in `datums` (i.e. `datums[0]`) from the list elements'
195            // current type to a target type.
196            cast_datums.push(self.cast_expr.eval(&[el], temp_storage)?);
197        }
198
199        Ok(temp_storage.make_datum(|packer| packer.push_list(cast_datums)))
200    }
201
202    fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
203        self.return_ty
204            .without_modifiers()
205            .nullable(input_type.nullable)
206    }
207
208    fn propagates_nulls(&self) -> bool {
209        true
210    }
211
212    fn introduces_nulls(&self) -> bool {
213        false
214    }
215
216    fn preserves_uniqueness(&self) -> bool {
217        false
218    }
219
220    fn inverse(&self) -> Option<crate::UnaryFunc> {
221        // TODO: this could be figured out--might be easier after enum dispatch?
222        None
223    }
224
225    fn is_monotone(&self) -> bool {
226        false
227    }
228}
229
230impl fmt::Display for CastList1ToList2 {
231    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
232        f.write_str("list1tolist2")
233    }
234}
235
236#[sqlfunc(sqlname = "list_length")]
237fn list_length<'a>(a: DatumList<'a>) -> Result<i32, EvalError> {
238    let count = a.iter().count();
239    count
240        .try_into()
241        .map_err(|_| EvalError::Int32OutOfRange(count.to_string().into()))
242}
243
244/// The `list_length_max` implementation.
245///
246/// We're not deriving `sqlfunc` here because we need to pass in the `max_layer` parameter.
247#[derive(
248    Ord,
249    PartialOrd,
250    Clone,
251    Debug,
252    Eq,
253    PartialEq,
254    Serialize,
255    Deserialize,
256    Hash,
257    MzReflect
258)]
259pub struct ListLengthMax {
260    /// Maximal allowed layer to query.
261    pub max_layer: usize,
262}
263impl EagerBinaryFunc for ListLengthMax {
264    type Input<'a> = (DatumList<'a>, i64);
265    type Output<'a> = Result<Option<i32>, EvalError>;
266    // TODO(benesch): remove potentially dangerous usage of `as`.
267    #[allow(clippy::as_conversions)]
268    fn call<'a>(&self, (a, b): Self::Input<'a>, _: &'a RowArena) -> Self::Output<'a> {
269        fn max_len_on_layer(i: DatumList<'_>, on_layer: i64) -> Option<usize> {
270            let mut i = i.iter();
271            if on_layer > 1 {
272                let mut max_len = None;
273                while let Some(Datum::List(i)) = i.next() {
274                    max_len = std::cmp::max(max_len_on_layer(i, on_layer - 1), max_len);
275                }
276                max_len
277            } else {
278                Some(i.count())
279            }
280        }
281        if b as usize > self.max_layer || b < 1 {
282            Err(EvalError::InvalidLayer {
283                max_layer: self.max_layer,
284                val: b,
285            })
286        } else {
287            match max_len_on_layer(a, b) {
288                Some(l) => match l.try_into() {
289                    Ok(c) => Ok(Some(c)),
290                    Err(_) => Err(EvalError::Int32OutOfRange(l.to_string().into())),
291                },
292                None => Ok(None),
293            }
294        }
295    }
296    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
297        let output = Self::Output::as_column_type();
298        let propagates_nulls = self.propagates_nulls();
299        let nullable = output.nullable;
300        let input_nullable = input_types.iter().any(|t| t.nullable);
301        output.nullable(nullable || (propagates_nulls && input_nullable))
302    }
303}
304impl fmt::Display for ListLengthMax {
305    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
306        f.write_str("list_length_max")
307    }
308}