mz_expr/scalar/func/impls/
numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_lowertest::MzReflect;
14use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
15use mz_repr::{SqlColumnType, SqlScalarType, strconv};
16use serde::{Deserialize, Serialize};
17
18use crate::EvalError;
19use crate::scalar::func::EagerUnaryFunc;
20
21sqlfunc!(
22    #[sqlname = "-"]
23    #[preserves_uniqueness = true]
24    #[inverse = to_unary!(NegNumeric)]
25    #[is_monotone = true]
26    fn neg_numeric(mut a: Numeric) -> Numeric {
27        numeric::cx_datum().neg(&mut a);
28        numeric::munge_numeric(&mut a).unwrap();
29        a
30    }
31);
32
33sqlfunc!(
34    #[sqlname = "abs"]
35    fn abs_numeric(mut a: Numeric) -> Numeric {
36        numeric::cx_datum().abs(&mut a);
37        a
38    }
39);
40
41sqlfunc!(
42    #[sqlname = "ceilnumeric"]
43    #[is_monotone = true]
44    fn ceil_numeric(mut a: Numeric) -> Numeric {
45        // ceil will be nop if has no fractional digits.
46        if a.exponent() >= 0 {
47            return a;
48        }
49        let mut cx = numeric::cx_datum();
50        cx.set_rounding(Rounding::Ceiling);
51        cx.round(&mut a);
52        numeric::munge_numeric(&mut a).unwrap();
53        a
54    }
55);
56
57sqlfunc!(
58    #[sqlname = "expnumeric"]
59    fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
60        let mut cx = numeric::cx_datum();
61        cx.exp(&mut a);
62        let cx_status = cx.status();
63        if cx_status.overflow() {
64            Err(EvalError::FloatOverflow)
65        } else if cx_status.subnormal() {
66            Err(EvalError::FloatUnderflow)
67        } else {
68            numeric::munge_numeric(&mut a).unwrap();
69            Ok(a)
70        }
71    }
72);
73
74sqlfunc!(
75    #[sqlname = "floornumeric"]
76    #[is_monotone = true]
77    fn floor_numeric(mut a: Numeric) -> Numeric {
78        // floor will be nop if has no fractional digits.
79        if a.exponent() >= 0 {
80            return a;
81        }
82        let mut cx = numeric::cx_datum();
83        cx.set_rounding(Rounding::Floor);
84        cx.round(&mut a);
85        numeric::munge_numeric(&mut a).unwrap();
86        a
87    }
88);
89
90fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
91    if val.is_negative() {
92        return Err(EvalError::NegativeOutOfDomain(function_name.into()));
93    }
94    if val.is_zero() {
95        return Err(EvalError::ZeroOutOfDomain(function_name.into()));
96    }
97    Ok(())
98}
99
100// From the `decNumber` library's documentation:
101// > Inexact results will almost always be correctly rounded, but may be up to 1
102// > ulp (unit in last place) in error in rare cases.
103//
104// See decNumberLog10 documentation at http://speleotrove.com/decimal/dnnumb.html
105fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
106where
107    F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
108{
109    log_guard_numeric(&a, name)?;
110    let mut cx = numeric::cx_datum();
111    logic(&mut cx, &mut a);
112    numeric::munge_numeric(&mut a).unwrap();
113    Ok(a)
114}
115
116sqlfunc!(
117    #[sqlname = "lnnumeric"]
118    fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
119        log_numeric(a, dec::Context::ln, "ln")
120    }
121);
122
123sqlfunc!(
124    #[sqlname = "log10numeric"]
125    fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
126        log_numeric(a, dec::Context::log10, "log10")
127    }
128);
129
130sqlfunc!(
131    #[sqlname = "roundnumeric"]
132    #[is_monotone = true]
133    fn round_numeric(mut a: Numeric) -> Numeric {
134        // round will be nop if has no fractional digits.
135        if a.exponent() >= 0 {
136            return a;
137        }
138        numeric::cx_datum().round(&mut a);
139        a
140    }
141);
142
143sqlfunc!(
144    #[sqlname = "truncnumeric"]
145    #[is_monotone = true]
146    fn trunc_numeric(mut a: Numeric) -> Numeric {
147        // trunc will be nop if has no fractional digits.
148        if a.exponent() >= 0 {
149            return a;
150        }
151        let mut cx = numeric::cx_datum();
152        cx.set_rounding(Rounding::Down);
153        cx.round(&mut a);
154        numeric::munge_numeric(&mut a).unwrap();
155        a
156    }
157);
158
159sqlfunc!(
160    #[sqlname = "sqrtnumeric"]
161    fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
162        if a.is_negative() {
163            return Err(EvalError::NegSqrt);
164        }
165        let mut cx = numeric::cx_datum();
166        cx.sqrt(&mut a);
167        numeric::munge_numeric(&mut a).unwrap();
168        Ok(a)
169    }
170);
171
172sqlfunc!(
173    #[sqlname = "numeric_to_smallint"]
174    #[preserves_uniqueness = false]
175    #[inverse = to_unary!(super::CastInt16ToNumeric(None))]
176    #[is_monotone = true]
177    fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
178        let mut cx = numeric::cx_datum();
179        cx.round(&mut a);
180        cx.clear_status();
181        let i = cx
182            .try_into_i32(a)
183            .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
184        i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
185    }
186);
187
188sqlfunc!(
189    #[sqlname = "numeric_to_integer"]
190    #[preserves_uniqueness = false]
191    #[inverse = to_unary!(super::CastInt32ToNumeric(None))]
192    #[is_monotone = true]
193    fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
194        let mut cx = numeric::cx_datum();
195        cx.round(&mut a);
196        cx.clear_status();
197        cx.try_into_i32(a)
198            .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
199    }
200);
201
202sqlfunc!(
203    #[sqlname = "numeric_to_bigint"]
204    #[preserves_uniqueness = false]
205    #[inverse = to_unary!(super::CastInt64ToNumeric(None))]
206    #[is_monotone = true]
207    fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
208        let mut cx = numeric::cx_datum();
209        cx.round(&mut a);
210        cx.clear_status();
211        cx.try_into_i64(a)
212            .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
213    }
214);
215
216sqlfunc!(
217    #[sqlname = "numeric_to_real"]
218    #[preserves_uniqueness = false]
219    #[inverse = to_unary!(super::CastFloat32ToNumeric(None))]
220    #[is_monotone = true]
221    fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
222        let i = a.to_string().parse::<f32>().unwrap();
223        if i.is_infinite() {
224            Err(EvalError::Float32OutOfRange(i.to_string().into()))
225        } else {
226            Ok(i)
227        }
228    }
229);
230
231sqlfunc!(
232    #[sqlname = "numeric_to_double"]
233    #[preserves_uniqueness = false]
234    #[inverse = to_unary!(super::CastFloat64ToNumeric(None))]
235    #[is_monotone = true]
236    fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
237        let i = a.to_string().parse::<f64>().unwrap();
238        if i.is_infinite() {
239            Err(EvalError::Float64OutOfRange(i.to_string().into()))
240        } else {
241            Ok(i)
242        }
243    }
244);
245
246sqlfunc!(
247    #[sqlname = "numeric_to_text"]
248    #[preserves_uniqueness = false]
249    #[inverse = to_unary!(super::CastStringToNumeric(None))]
250    fn cast_numeric_to_string(a: Numeric) -> String {
251        let mut buf = String::new();
252        strconv::format_numeric(&mut buf, &OrderedDecimal(a));
253        buf
254    }
255);
256
257sqlfunc!(
258    #[sqlname = "numeric_to_uint2"]
259    #[preserves_uniqueness = false]
260    #[inverse = to_unary!(super::CastUint16ToNumeric(None))]
261    #[is_monotone = true]
262    fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
263        let mut cx = numeric::cx_datum();
264        cx.round(&mut a);
265        cx.clear_status();
266        let u = cx
267            .try_into_u32(a)
268            .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
269        u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
270    }
271);
272
273sqlfunc!(
274    #[sqlname = "numeric_to_uint4"]
275    #[preserves_uniqueness = false]
276    #[inverse = to_unary!(super::CastUint32ToNumeric(None))]
277    #[is_monotone = true]
278    fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
279        let mut cx = numeric::cx_datum();
280        cx.round(&mut a);
281        cx.clear_status();
282        cx.try_into_u32(a)
283            .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
284    }
285);
286
287sqlfunc!(
288    #[sqlname = "numeric_to_uint8"]
289    #[preserves_uniqueness = false]
290    #[inverse = to_unary!(super::CastUint64ToNumeric(None))]
291    #[is_monotone = true]
292    fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
293        let mut cx = numeric::cx_datum();
294        cx.round(&mut a);
295        cx.clear_status();
296        cx.try_into_u64(a)
297            .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
298    }
299);
300
301sqlfunc!(
302    #[sqlname = "pg_size_pretty"]
303    #[preserves_uniqueness = false]
304    fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
305        let mut cx = numeric::cx_datum();
306        let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
307
308        for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
309            // return if abs(round(a)) < 10 in the next unit it would be converted to.
310            if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
311                // do not round a when the unit is bytes, as no conversion has happened.
312                if pos > 0 {
313                    cx.round(&mut a);
314                }
315
316                return Ok(format!("{} {unit}", a.to_standard_notation_string()));
317            }
318
319            cx.div(&mut a, &Numeric::from(1024));
320            numeric::munge_numeric(&mut a).unwrap();
321        }
322
323        cx.round(&mut a);
324        Ok(format!(
325            "{} {}",
326            a.to_standard_notation_string(),
327            units.last().unwrap()
328        ))
329    }
330);
331
332#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
333pub struct AdjustNumericScale(pub NumericMaxScale);
334
335impl<'a> EagerUnaryFunc<'a> for AdjustNumericScale {
336    type Input = Numeric;
337    type Output = Result<Numeric, EvalError>;
338
339    fn call(&self, mut d: Numeric) -> Result<Numeric, EvalError> {
340        if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
341            return Err(EvalError::NumericFieldOverflow);
342        };
343        Ok(d)
344    }
345
346    fn output_type(&self, input: SqlColumnType) -> SqlColumnType {
347        SqlScalarType::Numeric {
348            max_scale: Some(self.0),
349        }
350        .nullable(input.nullable)
351    }
352}
353
354impl fmt::Display for AdjustNumericScale {
355    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
356        f.write_str("adjust_numeric_scale")
357    }
358}