mz_expr/scalar/func/impls/
numeric.rs1use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_expr_derive::sqlfunc;
14use mz_lowertest::MzReflect;
15use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
16use mz_repr::{SqlColumnType, SqlScalarType, strconv};
17use serde::{Deserialize, Serialize};
18
19use crate::EvalError;
20use crate::scalar::func::EagerUnaryFunc;
21
22#[sqlfunc(
23 sqlname = "-",
24 preserves_uniqueness = true,
25 inverse = to_unary!(NegNumeric),
26 is_monotone = true
27)]
28fn neg_numeric(mut a: Numeric) -> Numeric {
29 numeric::cx_datum().neg(&mut a);
30 numeric::munge_numeric(&mut a).unwrap();
31 a
32}
33
34#[sqlfunc(sqlname = "abs")]
35fn abs_numeric(mut a: Numeric) -> Numeric {
36 numeric::cx_datum().abs(&mut a);
37 a
38}
39
40#[sqlfunc(sqlname = "ceilnumeric", is_monotone = true)]
41fn ceil_numeric(mut a: Numeric) -> Numeric {
42 if a.exponent() >= 0 {
44 return a;
45 }
46 let mut cx = numeric::cx_datum();
47 cx.set_rounding(Rounding::Ceiling);
48 cx.round(&mut a);
49 numeric::munge_numeric(&mut a).unwrap();
50 a
51}
52
53#[sqlfunc(sqlname = "expnumeric")]
54fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
55 let mut cx = numeric::cx_datum();
56 cx.exp(&mut a);
57 let cx_status = cx.status();
58 if cx_status.overflow() {
59 Err(EvalError::FloatOverflow)
60 } else if cx_status.subnormal() {
61 Err(EvalError::FloatUnderflow)
62 } else {
63 numeric::munge_numeric(&mut a).unwrap();
64 Ok(a)
65 }
66}
67
68#[sqlfunc(sqlname = "floornumeric", is_monotone = true)]
69fn floor_numeric(mut a: Numeric) -> Numeric {
70 if a.exponent() >= 0 {
72 return a;
73 }
74 let mut cx = numeric::cx_datum();
75 cx.set_rounding(Rounding::Floor);
76 cx.round(&mut a);
77 numeric::munge_numeric(&mut a).unwrap();
78 a
79}
80
81fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
82 if val.is_negative() {
83 return Err(EvalError::NegativeOutOfDomain(function_name.into()));
84 }
85 if val.is_zero() {
86 return Err(EvalError::ZeroOutOfDomain(function_name.into()));
87 }
88 Ok(())
89}
90
91fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
97where
98 F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
99{
100 log_guard_numeric(&a, name)?;
101 let mut cx = numeric::cx_datum();
102 logic(&mut cx, &mut a);
103 numeric::munge_numeric(&mut a).unwrap();
104 Ok(a)
105}
106
107#[sqlfunc(sqlname = "lnnumeric")]
108fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
109 log_numeric(a, dec::Context::ln, "ln")
110}
111
112#[sqlfunc(sqlname = "log10numeric")]
113fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
114 log_numeric(a, dec::Context::log10, "log10")
115}
116
117#[sqlfunc(sqlname = "roundnumeric", is_monotone = true)]
118fn round_numeric(mut a: Numeric) -> Numeric {
119 if a.exponent() >= 0 {
121 return a;
122 }
123 numeric::cx_datum().round(&mut a);
124 a
125}
126
127#[sqlfunc(sqlname = "truncnumeric", is_monotone = true)]
128fn trunc_numeric(mut a: Numeric) -> Numeric {
129 if a.exponent() >= 0 {
131 return a;
132 }
133 let mut cx = numeric::cx_datum();
134 cx.set_rounding(Rounding::Down);
135 cx.round(&mut a);
136 numeric::munge_numeric(&mut a).unwrap();
137 a
138}
139
140#[sqlfunc(sqlname = "sqrtnumeric")]
141fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
142 if a.is_negative() {
143 return Err(EvalError::NegSqrt);
144 }
145 let mut cx = numeric::cx_datum();
146 cx.sqrt(&mut a);
147 numeric::munge_numeric(&mut a).unwrap();
148 Ok(a)
149}
150
151#[sqlfunc(
152 sqlname = "numeric_to_smallint",
153 preserves_uniqueness = false,
154 inverse = to_unary!(super::CastInt16ToNumeric(None)),
155 is_monotone = true
156)]
157pub fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
158 let mut cx = numeric::cx_datum();
159 cx.round(&mut a);
160 cx.clear_status();
161 let i = cx
162 .try_into_i32(a)
163 .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
164 i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
165}
166
167#[sqlfunc(
168 sqlname = "numeric_to_integer",
169 preserves_uniqueness = false,
170 inverse = to_unary!(super::CastInt32ToNumeric(None)),
171 is_monotone = true
172)]
173pub fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
174 let mut cx = numeric::cx_datum();
175 cx.round(&mut a);
176 cx.clear_status();
177 cx.try_into_i32(a)
178 .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
179}
180
181#[sqlfunc(
182 sqlname = "numeric_to_bigint",
183 preserves_uniqueness = false,
184 inverse = to_unary!(super::CastInt64ToNumeric(None)),
185 is_monotone = true
186)]
187pub fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
188 let mut cx = numeric::cx_datum();
189 cx.round(&mut a);
190 cx.clear_status();
191 cx.try_into_i64(a)
192 .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
193}
194
195#[sqlfunc(
196 sqlname = "numeric_to_real",
197 preserves_uniqueness = false,
198 inverse = to_unary!(super::CastFloat32ToNumeric(None)),
199 is_monotone = true
200)]
201pub fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
202 let i = a.to_string().parse::<f32>().unwrap();
203 if i.is_infinite() {
204 Err(EvalError::Float32OutOfRange(i.to_string().into()))
205 } else {
206 Ok(i)
207 }
208}
209
210#[sqlfunc(
211 sqlname = "numeric_to_double",
212 preserves_uniqueness = false,
213 inverse = to_unary!(super::CastFloat64ToNumeric(None)),
214 is_monotone = true
215)]
216pub fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
217 let i = a.to_string().parse::<f64>().unwrap();
218 if i.is_infinite() {
219 Err(EvalError::Float64OutOfRange(i.to_string().into()))
220 } else {
221 Ok(i)
222 }
223}
224
225#[sqlfunc(
226 sqlname = "numeric_to_text",
227 preserves_uniqueness = false,
228 inverse = to_unary!(super::CastStringToNumeric(None))
229)]
230fn cast_numeric_to_string(a: Numeric) -> String {
231 let mut buf = String::new();
232 strconv::format_numeric(&mut buf, &OrderedDecimal(a));
233 buf
234}
235
236#[sqlfunc(
237 sqlname = "numeric_to_uint2",
238 preserves_uniqueness = false,
239 inverse = to_unary!(super::CastUint16ToNumeric(None)),
240 is_monotone = true
241)]
242fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
243 let mut cx = numeric::cx_datum();
244 cx.round(&mut a);
245 cx.clear_status();
246 let u = cx
247 .try_into_u32(a)
248 .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
249 u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
250}
251
252#[sqlfunc(
253 sqlname = "numeric_to_uint4",
254 preserves_uniqueness = false,
255 inverse = to_unary!(super::CastUint32ToNumeric(None)),
256 is_monotone = true
257)]
258fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
259 let mut cx = numeric::cx_datum();
260 cx.round(&mut a);
261 cx.clear_status();
262 cx.try_into_u32(a)
263 .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
264}
265
266#[sqlfunc(
267 sqlname = "numeric_to_uint8",
268 preserves_uniqueness = false,
269 inverse = to_unary!(super::CastUint64ToNumeric(None)),
270 is_monotone = true
271)]
272fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
273 let mut cx = numeric::cx_datum();
274 cx.round(&mut a);
275 cx.clear_status();
276 cx.try_into_u64(a)
277 .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
278}
279
280#[sqlfunc(sqlname = "pg_size_pretty", preserves_uniqueness = false)]
281fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
282 let mut cx = numeric::cx_datum();
283 let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
284
285 for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
286 if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
288 if pos > 0 {
290 cx.round(&mut a);
291 }
292
293 return Ok(format!("{} {unit}", a.to_standard_notation_string()));
294 }
295
296 cx.div(&mut a, &Numeric::from(1024));
297 numeric::munge_numeric(&mut a).unwrap();
298 }
299
300 cx.round(&mut a);
301 Ok(format!(
302 "{} {}",
303 a.to_standard_notation_string(),
304 units.last().unwrap()
305 ))
306}
307
308#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect)]
309pub struct AdjustNumericScale(pub NumericMaxScale);
310
311impl<'a> EagerUnaryFunc<'a> for AdjustNumericScale {
312 type Input = Numeric;
313 type Output = Result<Numeric, EvalError>;
314
315 fn call(&self, mut d: Numeric) -> Result<Numeric, EvalError> {
316 if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
317 return Err(EvalError::NumericFieldOverflow);
318 };
319 Ok(d)
320 }
321
322 fn output_type(&self, input: SqlColumnType) -> SqlColumnType {
323 SqlScalarType::Numeric {
324 max_scale: Some(self.0),
325 }
326 .nullable(input.nullable)
327 }
328}
329
330impl fmt::Display for AdjustNumericScale {
331 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
332 f.write_str("adjust_numeric_scale")
333 }
334}