// mz_expr/scalar/func/impls/numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_lowertest::MzReflect;
14use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
15use mz_repr::{ColumnType, ScalarType, strconv};
16use proptest_derive::Arbitrary;
17use serde::{Deserialize, Serialize};
18
19use crate::EvalError;
20use crate::scalar::func::EagerUnaryFunc;
21
22sqlfunc!(
23    #[sqlname = "-"]
24    #[preserves_uniqueness = true]
25    #[inverse = to_unary!(NegNumeric)]
26    #[is_monotone = true]
27    fn neg_numeric(mut a: Numeric) -> Numeric {
28        numeric::cx_datum().neg(&mut a);
29        numeric::munge_numeric(&mut a).unwrap();
30        a
31    }
32);
33
34sqlfunc!(
35    #[sqlname = "abs"]
36    fn abs_numeric(mut a: Numeric) -> Numeric {
37        numeric::cx_datum().abs(&mut a);
38        a
39    }
40);
41
42sqlfunc!(
43    #[sqlname = "ceilnumeric"]
44    #[is_monotone = true]
45    fn ceil_numeric(mut a: Numeric) -> Numeric {
46        // ceil will be nop if has no fractional digits.
47        if a.exponent() >= 0 {
48            return a;
49        }
50        let mut cx = numeric::cx_datum();
51        cx.set_rounding(Rounding::Ceiling);
52        cx.round(&mut a);
53        numeric::munge_numeric(&mut a).unwrap();
54        a
55    }
56);
57
58sqlfunc!(
59    #[sqlname = "expnumeric"]
60    fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
61        let mut cx = numeric::cx_datum();
62        cx.exp(&mut a);
63        let cx_status = cx.status();
64        if cx_status.overflow() {
65            Err(EvalError::FloatOverflow)
66        } else if cx_status.subnormal() {
67            Err(EvalError::FloatUnderflow)
68        } else {
69            numeric::munge_numeric(&mut a).unwrap();
70            Ok(a)
71        }
72    }
73);
74
75sqlfunc!(
76    #[sqlname = "floornumeric"]
77    #[is_monotone = true]
78    fn floor_numeric(mut a: Numeric) -> Numeric {
79        // floor will be nop if has no fractional digits.
80        if a.exponent() >= 0 {
81            return a;
82        }
83        let mut cx = numeric::cx_datum();
84        cx.set_rounding(Rounding::Floor);
85        cx.round(&mut a);
86        numeric::munge_numeric(&mut a).unwrap();
87        a
88    }
89);
90
91fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
92    if val.is_negative() {
93        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
94    }
95    if val.is_zero() {
96        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
97    }
98    Ok(())
99}
100
101// From the `decNumber` library's documentation:
102// > Inexact results will almost always be correctly rounded, but may be up to 1
103// > ulp (unit in last place) in error in rare cases.
104//
105// See decNumberLog10 documentation at http://speleotrove.com/decimal/dnnumb.html
106fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
107where
108    F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
109{
110    log_guard_numeric(&a, name)?;
111    let mut cx = numeric::cx_datum();
112    logic(&mut cx, &mut a);
113    numeric::munge_numeric(&mut a).unwrap();
114    Ok(a)
115}
116
117sqlfunc!(
118    #[sqlname = "lnnumeric"]
119    fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
120        log_numeric(a, dec::Context::ln, "ln")
121    }
122);
123
124sqlfunc!(
125    #[sqlname = "log10numeric"]
126    fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
127        log_numeric(a, dec::Context::log10, "log10")
128    }
129);
130
131sqlfunc!(
132    #[sqlname = "roundnumeric"]
133    #[is_monotone = true]
134    fn round_numeric(mut a: Numeric) -> Numeric {
135        // round will be nop if has no fractional digits.
136        if a.exponent() >= 0 {
137            return a;
138        }
139        numeric::cx_datum().round(&mut a);
140        a
141    }
142);
143
144sqlfunc!(
145    #[sqlname = "truncnumeric"]
146    #[is_monotone = true]
147    fn trunc_numeric(mut a: Numeric) -> Numeric {
148        // trunc will be nop if has no fractional digits.
149        if a.exponent() >= 0 {
150            return a;
151        }
152        let mut cx = numeric::cx_datum();
153        cx.set_rounding(Rounding::Down);
154        cx.round(&mut a);
155        numeric::munge_numeric(&mut a).unwrap();
156        a
157    }
158);
159
160sqlfunc!(
161    #[sqlname = "sqrtnumeric"]
162    fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
163        if a.is_negative() {
164            return Err(EvalError::NegSqrt);
165        }
166        let mut cx = numeric::cx_datum();
167        cx.sqrt(&mut a);
168        numeric::munge_numeric(&mut a).unwrap();
169        Ok(a)
170    }
171);
172
173sqlfunc!(
174    #[sqlname = "numeric_to_smallint"]
175    #[preserves_uniqueness = false]
176    #[inverse = to_unary!(super::CastInt16ToNumeric(None))]
177    #[is_monotone = true]
178    fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
179        let mut cx = numeric::cx_datum();
180        cx.round(&mut a);
181        cx.clear_status();
182        let i = cx
183            .try_into_i32(a)
184            .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
185        i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
186    }
187);
188
189sqlfunc!(
190    #[sqlname = "numeric_to_integer"]
191    #[preserves_uniqueness = false]
192    #[inverse = to_unary!(super::CastInt32ToNumeric(None))]
193    #[is_monotone = true]
194    fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
195        let mut cx = numeric::cx_datum();
196        cx.round(&mut a);
197        cx.clear_status();
198        cx.try_into_i32(a)
199            .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
200    }
201);
202
203sqlfunc!(
204    #[sqlname = "numeric_to_bigint"]
205    #[preserves_uniqueness = false]
206    #[inverse = to_unary!(super::CastInt64ToNumeric(None))]
207    #[is_monotone = true]
208    fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
209        let mut cx = numeric::cx_datum();
210        cx.round(&mut a);
211        cx.clear_status();
212        cx.try_into_i64(a)
213            .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
214    }
215);
216
217sqlfunc!(
218    #[sqlname = "numeric_to_real"]
219    #[preserves_uniqueness = false]
220    #[inverse = to_unary!(super::CastFloat32ToNumeric(None))]
221    #[is_monotone = true]
222    fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
223        let i = a.to_string().parse::<f32>().unwrap();
224        if i.is_infinite() {
225            Err(EvalError::Float32OutOfRange(i.to_string().into()))
226        } else {
227            Ok(i)
228        }
229    }
230);
231
232sqlfunc!(
233    #[sqlname = "numeric_to_double"]
234    #[preserves_uniqueness = false]
235    #[inverse = to_unary!(super::CastFloat64ToNumeric(None))]
236    #[is_monotone = true]
237    fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
238        let i = a.to_string().parse::<f64>().unwrap();
239        if i.is_infinite() {
240            Err(EvalError::Float64OutOfRange(i.to_string().into()))
241        } else {
242            Ok(i)
243        }
244    }
245);
246
247sqlfunc!(
248    #[sqlname = "numeric_to_text"]
249    #[preserves_uniqueness = false]
250    #[inverse = to_unary!(super::CastStringToNumeric(None))]
251    fn cast_numeric_to_string(a: Numeric) -> String {
252        let mut buf = String::new();
253        strconv::format_numeric(&mut buf, &OrderedDecimal(a));
254        buf
255    }
256);
257
258sqlfunc!(
259    #[sqlname = "numeric_to_uint2"]
260    #[preserves_uniqueness = false]
261    #[inverse = to_unary!(super::CastUint16ToNumeric(None))]
262    #[is_monotone = true]
263    fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
264        let mut cx = numeric::cx_datum();
265        cx.round(&mut a);
266        cx.clear_status();
267        let u = cx
268            .try_into_u32(a)
269            .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
270        u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
271    }
272);
273
274sqlfunc!(
275    #[sqlname = "numeric_to_uint4"]
276    #[preserves_uniqueness = false]
277    #[inverse = to_unary!(super::CastUint32ToNumeric(None))]
278    #[is_monotone = true]
279    fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
280        let mut cx = numeric::cx_datum();
281        cx.round(&mut a);
282        cx.clear_status();
283        cx.try_into_u32(a)
284            .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
285    }
286);
287
288sqlfunc!(
289    #[sqlname = "numeric_to_uint8"]
290    #[preserves_uniqueness = false]
291    #[inverse = to_unary!(super::CastUint64ToNumeric(None))]
292    #[is_monotone = true]
293    fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
294        let mut cx = numeric::cx_datum();
295        cx.round(&mut a);
296        cx.clear_status();
297        cx.try_into_u64(a)
298            .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
299    }
300);
301
302sqlfunc!(
303    #[sqlname = "pg_size_pretty"]
304    #[preserves_uniqueness = false]
305    fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
306        let mut cx = numeric::cx_datum();
307        let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
308
309        for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
310            // return if abs(round(a)) < 10 in the next unit it would be converted to.
311            if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
312                // do not round a when the unit is bytes, as no conversion has happened.
313                if pos > 0 {
314                    cx.round(&mut a);
315                }
316
317                return Ok(format!("{} {unit}", a.to_standard_notation_string()));
318            }
319
320            cx.div(&mut a, &Numeric::from(1024));
321            numeric::munge_numeric(&mut a).unwrap();
322        }
323
324        cx.round(&mut a);
325        Ok(format!(
326            "{} {}",
327            a.to_standard_notation_string(),
328            units.last().unwrap()
329        ))
330    }
331);
332
333#[derive(
334    Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect,
335)]
336pub struct AdjustNumericScale(pub NumericMaxScale);
337
338impl<'a> EagerUnaryFunc<'a> for AdjustNumericScale {
339    type Input = Numeric;
340    type Output = Result<Numeric, EvalError>;
341
342    fn call(&self, mut d: Numeric) -> Result<Numeric, EvalError> {
343        if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
344            return Err(EvalError::NumericFieldOverflow);
345        };
346        Ok(d)
347    }
348
349    fn output_type(&self, input: ColumnType) -> ColumnType {
350        ScalarType::Numeric {
351            max_scale: Some(self.0),
352        }
353        .nullable(input.nullable)
354    }
355}
356
357impl fmt::Display for AdjustNumericScale {
358    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
359        f.write_str("adjust_numeric_scale")
360    }
361}