mz_expr/scalar/func/impls/numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_expr_derive::sqlfunc;
14use mz_lowertest::MzReflect;
15use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
16use mz_repr::{SqlColumnType, SqlScalarType, strconv};
17use serde::{Deserialize, Serialize};
18
19use crate::EvalError;
20use crate::scalar::func::EagerUnaryFunc;
21
22#[sqlfunc(
23    sqlname = "-",
24    preserves_uniqueness = true,
25    inverse = to_unary!(NegNumeric),
26    is_monotone = true
27)]
28fn neg_numeric(mut a: Numeric) -> Numeric {
29    numeric::cx_datum().neg(&mut a);
30    numeric::munge_numeric(&mut a).unwrap();
31    a
32}
33
34#[sqlfunc(sqlname = "abs")]
35fn abs_numeric(mut a: Numeric) -> Numeric {
36    numeric::cx_datum().abs(&mut a);
37    a
38}
39
40#[sqlfunc(sqlname = "ceilnumeric", is_monotone = true)]
41fn ceil_numeric(mut a: Numeric) -> Numeric {
42    // ceil will be nop if has no fractional digits.
43    if a.exponent() >= 0 {
44        return a;
45    }
46    let mut cx = numeric::cx_datum();
47    cx.set_rounding(Rounding::Ceiling);
48    cx.round(&mut a);
49    numeric::munge_numeric(&mut a).unwrap();
50    a
51}
52
53#[sqlfunc(sqlname = "expnumeric")]
54fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
55    let mut cx = numeric::cx_datum();
56    cx.exp(&mut a);
57    let cx_status = cx.status();
58    if cx_status.overflow() {
59        Err(EvalError::FloatOverflow)
60    } else if cx_status.subnormal() {
61        Err(EvalError::FloatUnderflow)
62    } else {
63        numeric::munge_numeric(&mut a).unwrap();
64        Ok(a)
65    }
66}
67
68#[sqlfunc(sqlname = "floornumeric", is_monotone = true)]
69fn floor_numeric(mut a: Numeric) -> Numeric {
70    // floor will be nop if has no fractional digits.
71    if a.exponent() >= 0 {
72        return a;
73    }
74    let mut cx = numeric::cx_datum();
75    cx.set_rounding(Rounding::Floor);
76    cx.round(&mut a);
77    numeric::munge_numeric(&mut a).unwrap();
78    a
79}
80
81fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
82    if val.is_negative() {
83        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
84    }
85    if val.is_zero() {
86        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
87    }
88    Ok(())
89}
90
91// From the `decNumber` library's documentation:
92// > Inexact results will almost always be correctly rounded, but may be up to 1
93// > ulp (unit in last place) in error in rare cases.
94//
95// See decNumberLog10 documentation at http://speleotrove.com/decimal/dnnumb.html
96fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
97where
98    F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
99{
100    log_guard_numeric(&a, name)?;
101    let mut cx = numeric::cx_datum();
102    logic(&mut cx, &mut a);
103    numeric::munge_numeric(&mut a).unwrap();
104    Ok(a)
105}
106
107#[sqlfunc(sqlname = "lnnumeric")]
108fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
109    log_numeric(a, dec::Context::ln, "ln")
110}
111
112#[sqlfunc(sqlname = "log10numeric")]
113fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
114    log_numeric(a, dec::Context::log10, "log10")
115}
116
117#[sqlfunc(sqlname = "roundnumeric", is_monotone = true)]
118fn round_numeric(mut a: Numeric) -> Numeric {
119    // round will be nop if has no fractional digits.
120    if a.exponent() >= 0 {
121        return a;
122    }
123    numeric::cx_datum().round(&mut a);
124    a
125}
126
127#[sqlfunc(sqlname = "truncnumeric", is_monotone = true)]
128fn trunc_numeric(mut a: Numeric) -> Numeric {
129    // trunc will be nop if has no fractional digits.
130    if a.exponent() >= 0 {
131        return a;
132    }
133    let mut cx = numeric::cx_datum();
134    cx.set_rounding(Rounding::Down);
135    cx.round(&mut a);
136    numeric::munge_numeric(&mut a).unwrap();
137    a
138}
139
140#[sqlfunc(sqlname = "sqrtnumeric")]
141fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
142    if a.is_negative() {
143        return Err(EvalError::NegSqrt);
144    }
145    let mut cx = numeric::cx_datum();
146    cx.sqrt(&mut a);
147    numeric::munge_numeric(&mut a).unwrap();
148    Ok(a)
149}
150
151#[sqlfunc(
152    sqlname = "numeric_to_smallint",
153    preserves_uniqueness = false,
154    inverse = to_unary!(super::CastInt16ToNumeric(None)),
155    is_monotone = true
156)]
157pub fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
158    let mut cx = numeric::cx_datum();
159    cx.round(&mut a);
160    cx.clear_status();
161    let i = cx
162        .try_into_i32(a)
163        .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
164    i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
165}
166
167#[sqlfunc(
168    sqlname = "numeric_to_integer",
169    preserves_uniqueness = false,
170    inverse = to_unary!(super::CastInt32ToNumeric(None)),
171    is_monotone = true
172)]
173pub fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
174    let mut cx = numeric::cx_datum();
175    cx.round(&mut a);
176    cx.clear_status();
177    cx.try_into_i32(a)
178        .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
179}
180
181#[sqlfunc(
182    sqlname = "numeric_to_bigint",
183    preserves_uniqueness = false,
184    inverse = to_unary!(super::CastInt64ToNumeric(None)),
185    is_monotone = true
186)]
187pub fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
188    let mut cx = numeric::cx_datum();
189    cx.round(&mut a);
190    cx.clear_status();
191    cx.try_into_i64(a)
192        .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
193}
194
195#[sqlfunc(
196    sqlname = "numeric_to_real",
197    preserves_uniqueness = false,
198    inverse = to_unary!(super::CastFloat32ToNumeric(None)),
199    is_monotone = true
200)]
201pub fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
202    let i = a.to_string().parse::<f32>().unwrap();
203    if i.is_infinite() {
204        Err(EvalError::Float32OutOfRange(i.to_string().into()))
205    } else {
206        Ok(i)
207    }
208}
209
210#[sqlfunc(
211    sqlname = "numeric_to_double",
212    preserves_uniqueness = false,
213    inverse = to_unary!(super::CastFloat64ToNumeric(None)),
214    is_monotone = true
215)]
216pub fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
217    let i = a.to_string().parse::<f64>().unwrap();
218    if i.is_infinite() {
219        Err(EvalError::Float64OutOfRange(i.to_string().into()))
220    } else {
221        Ok(i)
222    }
223}
224
225#[sqlfunc(
226    sqlname = "numeric_to_text",
227    preserves_uniqueness = false,
228    inverse = to_unary!(super::CastStringToNumeric(None))
229)]
230fn cast_numeric_to_string(a: Numeric) -> String {
231    let mut buf = String::new();
232    strconv::format_numeric(&mut buf, &OrderedDecimal(a));
233    buf
234}
235
236#[sqlfunc(
237    sqlname = "numeric_to_uint2",
238    preserves_uniqueness = false,
239    inverse = to_unary!(super::CastUint16ToNumeric(None)),
240    is_monotone = true
241)]
242fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
243    let mut cx = numeric::cx_datum();
244    cx.round(&mut a);
245    cx.clear_status();
246    let u = cx
247        .try_into_u32(a)
248        .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
249    u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
250}
251
252#[sqlfunc(
253    sqlname = "numeric_to_uint4",
254    preserves_uniqueness = false,
255    inverse = to_unary!(super::CastUint32ToNumeric(None)),
256    is_monotone = true
257)]
258fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
259    let mut cx = numeric::cx_datum();
260    cx.round(&mut a);
261    cx.clear_status();
262    cx.try_into_u32(a)
263        .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
264}
265
266#[sqlfunc(
267    sqlname = "numeric_to_uint8",
268    preserves_uniqueness = false,
269    inverse = to_unary!(super::CastUint64ToNumeric(None)),
270    is_monotone = true
271)]
272fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
273    let mut cx = numeric::cx_datum();
274    cx.round(&mut a);
275    cx.clear_status();
276    cx.try_into_u64(a)
277        .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
278}
279
280#[sqlfunc(sqlname = "pg_size_pretty", preserves_uniqueness = false)]
281fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
282    let mut cx = numeric::cx_datum();
283    let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
284
285    for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
286        // return if abs(round(a)) < 10 in the next unit it would be converted to.
287        if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
288            // do not round a when the unit is bytes, as no conversion has happened.
289            if pos > 0 {
290                cx.round(&mut a);
291            }
292
293            return Ok(format!("{} {unit}", a.to_standard_notation_string()));
294        }
295
296        cx.div(&mut a, &Numeric::from(1024));
297        numeric::munge_numeric(&mut a).unwrap();
298    }
299
300    cx.round(&mut a);
301    Ok(format!(
302        "{} {}",
303        a.to_standard_notation_string(),
304        units.last().unwrap()
305    ))
306}
307
308#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
309pub struct AdjustNumericScale(pub NumericMaxScale);
310
311impl<'a> EagerUnaryFunc<'a> for AdjustNumericScale {
312    type Input = Numeric;
313    type Output = Result<Numeric, EvalError>;
314
315    fn call(&self, mut d: Numeric) -> Result<Numeric, EvalError> {
316        if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
317            return Err(EvalError::NumericFieldOverflow);
318        };
319        Ok(d)
320    }
321
322    fn output_type(&self, input: SqlColumnType) -> SqlColumnType {
323        SqlScalarType::Numeric {
324            max_scale: Some(self.0),
325        }
326        .nullable(input.nullable)
327    }
328}
329
330impl fmt::Display for AdjustNumericScale {
331    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
332        f.write_str("adjust_numeric_scale")
333    }
334}