Skip to main content

mz_expr/scalar/func/impls/
numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_expr_derive::sqlfunc;
14use mz_lowertest::MzReflect;
15use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
16use mz_repr::{SqlColumnType, SqlScalarType, strconv};
17use serde::{Deserialize, Serialize};
18
19use crate::EvalError;
20use crate::scalar::func::EagerUnaryFunc;
21
22#[sqlfunc(
23    sqlname = "-",
24    preserves_uniqueness = true,
25    inverse = to_unary!(NegNumeric),
26    is_monotone = true
27)]
28fn neg_numeric(mut a: Numeric) -> Numeric {
29    numeric::cx_datum().neg(&mut a);
30    numeric::munge_numeric(&mut a).unwrap();
31    a
32}
33
34#[sqlfunc(sqlname = "abs")]
35fn abs_numeric(mut a: Numeric) -> Numeric {
36    numeric::cx_datum().abs(&mut a);
37    a
38}
39
40#[sqlfunc(sqlname = "ceilnumeric", is_monotone = true)]
41fn ceil_numeric(mut a: Numeric) -> Numeric {
42    // ceil will be nop if has no fractional digits.
43    if a.exponent() >= 0 {
44        return a;
45    }
46    let mut cx = numeric::cx_datum();
47    cx.set_rounding(Rounding::Ceiling);
48    cx.round(&mut a);
49    numeric::munge_numeric(&mut a).unwrap();
50    a
51}
52
53#[sqlfunc(sqlname = "expnumeric")]
54fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
55    let mut cx = numeric::cx_datum();
56    cx.exp(&mut a);
57    let cx_status = cx.status();
58    if cx_status.overflow() {
59        Err(EvalError::FloatOverflow)
60    } else if cx_status.subnormal() {
61        Err(EvalError::FloatUnderflow)
62    } else {
63        numeric::munge_numeric(&mut a).unwrap();
64        Ok(a)
65    }
66}
67
68#[sqlfunc(sqlname = "floornumeric", is_monotone = true)]
69fn floor_numeric(mut a: Numeric) -> Numeric {
70    // floor will be nop if has no fractional digits.
71    if a.exponent() >= 0 {
72        return a;
73    }
74    let mut cx = numeric::cx_datum();
75    cx.set_rounding(Rounding::Floor);
76    cx.round(&mut a);
77    numeric::munge_numeric(&mut a).unwrap();
78    a
79}
80
81fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
82    if val.is_negative() {
83        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
84    }
85    if val.is_zero() {
86        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
87    }
88    Ok(())
89}
90
91// From the `decNumber` library's documentation:
92// > Inexact results will almost always be correctly rounded, but may be up to 1
93// > ulp (unit in last place) in error in rare cases.
94//
95// See decNumberLog10 documentation at http://speleotrove.com/decimal/dnnumb.html
96fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
97where
98    F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
99{
100    log_guard_numeric(&a, name)?;
101    let mut cx = numeric::cx_datum();
102    logic(&mut cx, &mut a);
103    numeric::munge_numeric(&mut a).unwrap();
104    Ok(a)
105}
106
107#[sqlfunc(sqlname = "lnnumeric")]
108fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
109    log_numeric(a, dec::Context::ln, "ln")
110}
111
112#[sqlfunc(sqlname = "log10numeric")]
113fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
114    log_numeric(a, dec::Context::log10, "log10")
115}
116
117#[sqlfunc(sqlname = "roundnumeric", is_monotone = true)]
118fn round_numeric(mut a: Numeric) -> Numeric {
119    // round will be nop if has no fractional digits.
120    if a.exponent() >= 0 {
121        return a;
122    }
123    numeric::cx_datum().round(&mut a);
124    // Canonicalize: `dec`'s round preserves the sign on zero results, so e.g.
125    // `round(-0.4)` yields `-0`. munge_numeric strips that, ensuring row
126    // encodings match decimal equality.
127    numeric::munge_numeric(&mut a).unwrap();
128    a
129}
130
131#[sqlfunc(sqlname = "truncnumeric", is_monotone = true)]
132fn trunc_numeric(mut a: Numeric) -> Numeric {
133    // trunc will be nop if has no fractional digits.
134    if a.exponent() >= 0 {
135        return a;
136    }
137    let mut cx = numeric::cx_datum();
138    cx.set_rounding(Rounding::Down);
139    cx.round(&mut a);
140    numeric::munge_numeric(&mut a).unwrap();
141    a
142}
143
144#[sqlfunc(sqlname = "sqrtnumeric")]
145fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
146    if a.is_negative() {
147        return Err(EvalError::NegSqrt);
148    }
149    let mut cx = numeric::cx_datum();
150    cx.sqrt(&mut a);
151    numeric::munge_numeric(&mut a).unwrap();
152    Ok(a)
153}
154
155#[sqlfunc(
156    sqlname = "numeric_to_smallint",
157    preserves_uniqueness = false,
158    inverse = to_unary!(super::CastInt16ToNumeric(None)),
159    is_monotone = true
160)]
161pub fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
162    let mut cx = numeric::cx_datum();
163    cx.round(&mut a);
164    cx.clear_status();
165    let i = cx
166        .try_into_i32(a)
167        .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
168    i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
169}
170
171#[sqlfunc(
172    sqlname = "numeric_to_integer",
173    preserves_uniqueness = false,
174    inverse = to_unary!(super::CastInt32ToNumeric(None)),
175    is_monotone = true
176)]
177pub fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
178    let mut cx = numeric::cx_datum();
179    cx.round(&mut a);
180    cx.clear_status();
181    cx.try_into_i32(a)
182        .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
183}
184
185#[sqlfunc(
186    sqlname = "numeric_to_bigint",
187    preserves_uniqueness = false,
188    inverse = to_unary!(super::CastInt64ToNumeric(None)),
189    is_monotone = true
190)]
191pub fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
192    let mut cx = numeric::cx_datum();
193    cx.round(&mut a);
194    cx.clear_status();
195    cx.try_into_i64(a)
196        .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
197}
198
199#[sqlfunc(
200    sqlname = "numeric_to_real",
201    preserves_uniqueness = false,
202    inverse = to_unary!(super::CastFloat32ToNumeric(None)),
203    is_monotone = true
204)]
205pub fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
206    let i = a.to_string().parse::<f32>().unwrap();
207    if i.is_infinite() {
208        Err(EvalError::Float32OutOfRange(i.to_string().into()))
209    } else {
210        Ok(i)
211    }
212}
213
214#[sqlfunc(
215    sqlname = "numeric_to_double",
216    preserves_uniqueness = false,
217    inverse = to_unary!(super::CastFloat64ToNumeric(None)),
218    is_monotone = true
219)]
220pub fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
221    let i = a.to_string().parse::<f64>().unwrap();
222    if i.is_infinite() {
223        Err(EvalError::Float64OutOfRange(i.to_string().into()))
224    } else {
225        Ok(i)
226    }
227}
228
229#[sqlfunc(
230    sqlname = "numeric_to_text",
231    preserves_uniqueness = false,
232    inverse = to_unary!(super::CastStringToNumeric(None))
233)]
234fn cast_numeric_to_string(a: Numeric) -> String {
235    let mut buf = String::new();
236    strconv::format_numeric(&mut buf, &OrderedDecimal(a));
237    buf
238}
239
240#[sqlfunc(
241    sqlname = "numeric_to_uint2",
242    preserves_uniqueness = false,
243    inverse = to_unary!(super::CastUint16ToNumeric(None)),
244    is_monotone = true
245)]
246fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
247    let mut cx = numeric::cx_datum();
248    cx.round(&mut a);
249    cx.clear_status();
250    let u = cx
251        .try_into_u32(a)
252        .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
253    u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
254}
255
256#[sqlfunc(
257    sqlname = "numeric_to_uint4",
258    preserves_uniqueness = false,
259    inverse = to_unary!(super::CastUint32ToNumeric(None)),
260    is_monotone = true
261)]
262fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
263    let mut cx = numeric::cx_datum();
264    cx.round(&mut a);
265    cx.clear_status();
266    cx.try_into_u32(a)
267        .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
268}
269
270#[sqlfunc(
271    sqlname = "numeric_to_uint8",
272    preserves_uniqueness = false,
273    inverse = to_unary!(super::CastUint64ToNumeric(None)),
274    is_monotone = true
275)]
276fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
277    let mut cx = numeric::cx_datum();
278    cx.round(&mut a);
279    cx.clear_status();
280    cx.try_into_u64(a)
281        .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
282}
283
284#[sqlfunc(sqlname = "pg_size_pretty", preserves_uniqueness = false)]
285fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
286    let mut cx = numeric::cx_datum();
287    let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
288
289    for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
290        // return if abs(round(a)) < 10 in the next unit it would be converted to.
291        if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
292            // do not round a when the unit is bytes, as no conversion has happened.
293            if pos > 0 {
294                cx.round(&mut a);
295            }
296
297            return Ok(format!("{} {unit}", a.to_standard_notation_string()));
298        }
299
300        cx.div(&mut a, &Numeric::from(1024));
301        numeric::munge_numeric(&mut a).unwrap();
302    }
303
304    cx.round(&mut a);
305    Ok(format!(
306        "{} {}",
307        a.to_standard_notation_string(),
308        units.last().unwrap()
309    ))
310}
311
312#[derive(
313    Ord,
314    PartialOrd,
315    Clone,
316    Debug,
317    Eq,
318    PartialEq,
319    Serialize,
320    Deserialize,
321    Hash,
322    MzReflect
323)]
324pub struct AdjustNumericScale(pub NumericMaxScale);
325
326impl EagerUnaryFunc for AdjustNumericScale {
327    type Input<'a> = Numeric;
328    type Output<'a> = Result<Numeric, EvalError>;
329
330    fn call<'a>(&self, mut d: Self::Input<'a>) -> Self::Output<'a> {
331        if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
332            return Err(EvalError::NumericFieldOverflow);
333        };
334        Ok(d)
335    }
336
337    fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType {
338        SqlScalarType::Numeric {
339            max_scale: Some(self.0),
340        }
341        .nullable(input.nullable)
342    }
343}
344
345impl fmt::Display for AdjustNumericScale {
346    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
347        f.write_str("adjust_numeric_scale")
348    }
349}