mz_repr_test_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities to build objects from the `repr` crate for unit testing.
11//!
//! These test utilities are relied upon by crates other than `repr`.
13
14use chrono::NaiveDateTime;
15use mz_lowertest::deserialize_optional_generic;
16use mz_ore::str::StrExt;
17use mz_repr::adt::numeric::Numeric;
18use mz_repr::adt::timestamp::CheckedTimestamp;
19use mz_repr::strconv::parse_jsonb;
20use mz_repr::{Datum, Row, RowArena, SqlScalarType};
21use proc_macro2::TokenTree;
22
23/* #endregion */
24
/// Parses `litval` into a value of type `F` using `F`'s `FromStr`
/// implementation.
///
/// `littyp` is only used to name the target type in the error message.
///
/// Returns `Err` with a human-readable message if parsing fails.
fn parse_litval<'a, F>(litval: &'a str, littyp: &str) -> Result<F, String>
where
    F: std::str::FromStr,
    F::Err: ToString,
{
    match litval.parse::<F>() {
        Ok(parsed) => Ok(parsed),
        // The bound on `F::Err` is `ToString`, not `Display`, so the error
        // has to be stringified explicitly before it can be formatted.
        Err(cause) => Err(format!(
            "error when parsing {} into {}: {}",
            litval,
            littyp,
            cause.to_string()
        )),
    }
}
39
40/// Constructs a `Row` from a sequence of `litval` and `littyp`.
41///
42/// See [get_scalar_type_or_default] for creating a `SqlScalarType`.
43///
44/// Generally, each `litval` can be parsed into a Datum in the manner you would
45/// imagine. Exceptions:
46/// * A Timestamp should be in the format `"\"%Y-%m-%d %H:%M:%S%.f\""` or
47///   `"\"%Y-%m-%d %H:%M:%S\""`
48///
49/// Not all types are supported yet. Currently supported types:
50/// * string, bool, timestamp
51/// * all flavors of numeric types
52pub fn test_spec_to_row<'a, I>(datum_iter: I) -> Result<Row, String>
53where
54    I: Iterator<Item = (&'a str, &'a SqlScalarType)>,
55{
56    let temp_storage = RowArena::new();
57    Row::try_pack(datum_iter.map(|(litval, littyp)| {
58        if litval == "null" {
59            Ok(Datum::Null)
60        } else {
61            match littyp {
62                SqlScalarType::Bool => Ok(Datum::from(parse_litval::<bool>(litval, "bool")?)),
63                SqlScalarType::Numeric { .. } => {
64                    Ok(Datum::from(parse_litval::<Numeric>(litval, "Numeric")?))
65                }
66                SqlScalarType::Int16 => Ok(Datum::from(parse_litval::<i16>(litval, "i16")?)),
67                SqlScalarType::Int32 => Ok(Datum::from(parse_litval::<i32>(litval, "i32")?)),
68                SqlScalarType::Int64 => Ok(Datum::from(parse_litval::<i64>(litval, "i64")?)),
69                SqlScalarType::Float32 => Ok(Datum::from(parse_litval::<f32>(litval, "f32")?)),
70                SqlScalarType::Float64 => Ok(Datum::from(parse_litval::<f64>(litval, "f64")?)),
71                SqlScalarType::String => Ok(Datum::from(
72                    temp_storage.push_string(mz_lowertest::unquote(litval)),
73                )),
74                SqlScalarType::Timestamp { .. } => {
75                    let datetime = if litval.contains('.') {
76                        NaiveDateTime::parse_from_str(litval, "\"%Y-%m-%d %H:%M:%S%.f\"")
77                    } else {
78                        NaiveDateTime::parse_from_str(litval, "\"%Y-%m-%d %H:%M:%S\"")
79                    };
80                    Ok(Datum::from(
81                        CheckedTimestamp::from_timestamplike(
82                            datetime
83                                .map_err(|e| format!("Error while parsing NaiveDateTime: {}", e))?,
84                        )
85                        .unwrap(),
86                    ))
87                }
88                SqlScalarType::Jsonb => parse_jsonb(&mz_lowertest::unquote(litval))
89                    .map(|jsonb| temp_storage.push_unary_row(jsonb.into_row()))
90                    .map_err(|parse| format!("Invalid JSON literal: {:?}", parse)),
91                _ => Err(format!("Unsupported literal type {:?}", littyp)),
92            }
93        }
94    }))
95}
96
97/// Convert a Datum to a String such that [test_spec_to_row] can convert the
98/// String back into a row containing the same Datum.
99///
100/// Currently supports only Datums supported by [test_spec_to_row].
101pub fn datum_to_test_spec(datum: Datum) -> String {
102    let result = format!("{}", datum);
103    match datum {
104        Datum::Timestamp(_) => result.quoted().to_string(),
105        _ => result,
106    }
107}
108
109/// Parses `SqlScalarType` from `scalar_type_stream` or infers it from `litval`
110///
111/// See [mz_lowertest::to_json] for the syntax for specifying a `SqlScalarType`.
112/// If `scalar_type_stream` is empty, will attempt to guess a `SqlScalarType` for
113/// the literal:
114/// * If `litval` is "true", "false", or "null", will return `Bool`.
115/// * Else if starts with `'"'`, will return String.
116/// * Else if contains `'.'`, will return Float64.
117/// * Otherwise, returns Int64.
118pub fn get_scalar_type_or_default<I>(
119    litval: &str,
120    scalar_type_stream: &mut I,
121) -> Result<SqlScalarType, String>
122where
123    I: Iterator<Item = TokenTree>,
124{
125    let typ: Option<SqlScalarType> =
126        deserialize_optional_generic(scalar_type_stream, "SqlScalarType")?;
127    match typ {
128        Some(typ) => Ok(typ),
129        None => {
130            if ["true", "false", "null"].contains(&litval) {
131                Ok(SqlScalarType::Bool)
132            } else if litval.starts_with('\"') {
133                Ok(SqlScalarType::String)
134            } else if litval.contains('.') {
135                Ok(SqlScalarType::Float64)
136            } else {
137                Ok(SqlScalarType::Int64)
138            }
139        }
140    }
141}
142
143/// If the stream starts with a sequence of tokens that can be parsed as a datum,
144/// return those tokens as one string.
145///
146/// Sequences of tokens that can be parsed as a datum:
147/// * A Literal token, which is anything in quotations or a positive number
148/// * An null, false, or true Ident token
149/// * Punct(-) + a literal token
150///
151/// If the stream starts with a sequence of tokens that can be parsed as a
152/// datum, 1) returns Ok(Some(..)) 2) advances the stream to the first token
153/// that is not part of the sequence.
154/// If the stream does not start with tokens that can be parsed as a datum:
155/// * Return Ok(None) if `rest_of_stream` has not been advanced.
156/// * Returns Err(..) otherwise.
157pub fn extract_literal_string<I>(
158    first_arg: &TokenTree,
159    rest_of_stream: &mut I,
160) -> Result<Option<String>, String>
161where
162    I: Iterator<Item = TokenTree>,
163{
164    match first_arg {
165        TokenTree::Ident(ident) => {
166            if ["true", "false", "null"].contains(&&ident.to_string()[..]) {
167                Ok(Some(ident.to_string()))
168            } else {
169                Ok(None)
170            }
171        }
172        TokenTree::Literal(literal) => Ok(Some(literal.to_string())),
173        TokenTree::Punct(punct) if punct.as_char() == '-' => {
174            match rest_of_stream.next() {
175                Some(TokenTree::Literal(literal)) => {
176                    Ok(Some(format!("{}{}", punct.as_char(), literal)))
177                }
178                None => Ok(None),
179                // Must error instead of handling the tokens using default
180                // behavior since `stream_iter` has advanced.
181                Some(other) => Err(format!(
182                    "`{}` `{}` is not a valid literal",
183                    punct.as_char(),
184                    other
185                )),
186            }
187        }
188        _ => Ok(None),
189    }
190}
191
192/// Parse a token as a vec of strings that can be parsed as datums in a row.
193///
194/// The token is assumed to be of the form `[datum1 datum2 .. datumn]`.
195pub fn parse_vec_of_literals(token: &TokenTree) -> Result<Vec<String>, String> {
196    match token {
197        TokenTree::Group(group) => {
198            let mut inner_iter = group.stream().into_iter();
199            let mut result = Vec::new();
200            while let Some(symbol) = inner_iter.next() {
201                match extract_literal_string(&symbol, &mut inner_iter)? {
202                    Some(dat) => result.push(dat),
203                    None => {
204                        return Err(format!(
205                            "TokenTree `{}` cannot be interpreted as a literal.",
206                            symbol
207                        ));
208                    }
209                }
210            }
211            Ok(result)
212        }
213        invalid => Err(format!(
214            "TokenTree `{}` cannot be parsed as a vec of literals",
215            invalid
216        )),
217    }
218}