mz_expr/scalar/func/impls/
numeric.rs1use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_expr_derive::sqlfunc;
14use mz_lowertest::MzReflect;
15use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
16use mz_repr::{SqlColumnType, SqlScalarType, strconv};
17use serde::{Deserialize, Serialize};
18
19use crate::EvalError;
20use crate::scalar::func::EagerUnaryFunc;
21
22#[sqlfunc(
23 sqlname = "-",
24 preserves_uniqueness = true,
25 inverse = to_unary!(NegNumeric),
26 is_monotone = true
27)]
28fn neg_numeric(mut a: Numeric) -> Numeric {
29 numeric::cx_datum().neg(&mut a);
30 numeric::munge_numeric(&mut a).unwrap();
31 a
32}
33
34#[sqlfunc(sqlname = "abs")]
35fn abs_numeric(mut a: Numeric) -> Numeric {
36 numeric::cx_datum().abs(&mut a);
37 a
38}
39
40#[sqlfunc(sqlname = "ceilnumeric", is_monotone = true)]
41fn ceil_numeric(mut a: Numeric) -> Numeric {
42 if a.exponent() >= 0 {
44 return a;
45 }
46 let mut cx = numeric::cx_datum();
47 cx.set_rounding(Rounding::Ceiling);
48 cx.round(&mut a);
49 numeric::munge_numeric(&mut a).unwrap();
50 a
51}
52
53#[sqlfunc(sqlname = "expnumeric")]
54fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
55 let mut cx = numeric::cx_datum();
56 cx.exp(&mut a);
57 let cx_status = cx.status();
58 if cx_status.overflow() {
59 Err(EvalError::FloatOverflow)
60 } else if cx_status.subnormal() {
61 Err(EvalError::FloatUnderflow)
62 } else {
63 numeric::munge_numeric(&mut a).unwrap();
64 Ok(a)
65 }
66}
67
68#[sqlfunc(sqlname = "floornumeric", is_monotone = true)]
69fn floor_numeric(mut a: Numeric) -> Numeric {
70 if a.exponent() >= 0 {
72 return a;
73 }
74 let mut cx = numeric::cx_datum();
75 cx.set_rounding(Rounding::Floor);
76 cx.round(&mut a);
77 numeric::munge_numeric(&mut a).unwrap();
78 a
79}
80
81fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
82 if val.is_negative() {
83 return Err(EvalError::NegativeOutOfDomain(function_name.into()));
84 }
85 if val.is_zero() {
86 return Err(EvalError::ZeroOutOfDomain(function_name.into()));
87 }
88 Ok(())
89}
90
91fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
97where
98 F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
99{
100 log_guard_numeric(&a, name)?;
101 let mut cx = numeric::cx_datum();
102 logic(&mut cx, &mut a);
103 numeric::munge_numeric(&mut a).unwrap();
104 Ok(a)
105}
106
107#[sqlfunc(sqlname = "lnnumeric")]
108fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
109 log_numeric(a, dec::Context::ln, "ln")
110}
111
112#[sqlfunc(sqlname = "log10numeric")]
113fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
114 log_numeric(a, dec::Context::log10, "log10")
115}
116
117#[sqlfunc(sqlname = "roundnumeric", is_monotone = true)]
118fn round_numeric(mut a: Numeric) -> Numeric {
119 if a.exponent() >= 0 {
121 return a;
122 }
123 numeric::cx_datum().round(&mut a);
124 numeric::munge_numeric(&mut a).unwrap();
128 a
129}
130
131#[sqlfunc(sqlname = "truncnumeric", is_monotone = true)]
132fn trunc_numeric(mut a: Numeric) -> Numeric {
133 if a.exponent() >= 0 {
135 return a;
136 }
137 let mut cx = numeric::cx_datum();
138 cx.set_rounding(Rounding::Down);
139 cx.round(&mut a);
140 numeric::munge_numeric(&mut a).unwrap();
141 a
142}
143
144#[sqlfunc(sqlname = "sqrtnumeric")]
145fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
146 if a.is_negative() {
147 return Err(EvalError::NegSqrt);
148 }
149 let mut cx = numeric::cx_datum();
150 cx.sqrt(&mut a);
151 numeric::munge_numeric(&mut a).unwrap();
152 Ok(a)
153}
154
155#[sqlfunc(
156 sqlname = "numeric_to_smallint",
157 preserves_uniqueness = false,
158 inverse = to_unary!(super::CastInt16ToNumeric(None)),
159 is_monotone = true
160)]
161pub fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
162 let mut cx = numeric::cx_datum();
163 cx.round(&mut a);
164 cx.clear_status();
165 let i = cx
166 .try_into_i32(a)
167 .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
168 i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
169}
170
171#[sqlfunc(
172 sqlname = "numeric_to_integer",
173 preserves_uniqueness = false,
174 inverse = to_unary!(super::CastInt32ToNumeric(None)),
175 is_monotone = true
176)]
177pub fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
178 let mut cx = numeric::cx_datum();
179 cx.round(&mut a);
180 cx.clear_status();
181 cx.try_into_i32(a)
182 .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
183}
184
185#[sqlfunc(
186 sqlname = "numeric_to_bigint",
187 preserves_uniqueness = false,
188 inverse = to_unary!(super::CastInt64ToNumeric(None)),
189 is_monotone = true
190)]
191pub fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
192 let mut cx = numeric::cx_datum();
193 cx.round(&mut a);
194 cx.clear_status();
195 cx.try_into_i64(a)
196 .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
197}
198
199#[sqlfunc(
200 sqlname = "numeric_to_real",
201 preserves_uniqueness = false,
202 inverse = to_unary!(super::CastFloat32ToNumeric(None)),
203 is_monotone = true
204)]
205pub fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
206 let i = a.to_string().parse::<f32>().unwrap();
207 if i.is_infinite() {
208 Err(EvalError::Float32OutOfRange(i.to_string().into()))
209 } else {
210 Ok(i)
211 }
212}
213
214#[sqlfunc(
215 sqlname = "numeric_to_double",
216 preserves_uniqueness = false,
217 inverse = to_unary!(super::CastFloat64ToNumeric(None)),
218 is_monotone = true
219)]
220pub fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
221 let i = a.to_string().parse::<f64>().unwrap();
222 if i.is_infinite() {
223 Err(EvalError::Float64OutOfRange(i.to_string().into()))
224 } else {
225 Ok(i)
226 }
227}
228
229#[sqlfunc(
230 sqlname = "numeric_to_text",
231 preserves_uniqueness = false,
232 inverse = to_unary!(super::CastStringToNumeric(None))
233)]
234fn cast_numeric_to_string(a: Numeric) -> String {
235 let mut buf = String::new();
236 strconv::format_numeric(&mut buf, &OrderedDecimal(a));
237 buf
238}
239
240#[sqlfunc(
241 sqlname = "numeric_to_uint2",
242 preserves_uniqueness = false,
243 inverse = to_unary!(super::CastUint16ToNumeric(None)),
244 is_monotone = true
245)]
246fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
247 let mut cx = numeric::cx_datum();
248 cx.round(&mut a);
249 cx.clear_status();
250 let u = cx
251 .try_into_u32(a)
252 .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
253 u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
254}
255
256#[sqlfunc(
257 sqlname = "numeric_to_uint4",
258 preserves_uniqueness = false,
259 inverse = to_unary!(super::CastUint32ToNumeric(None)),
260 is_monotone = true
261)]
262fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
263 let mut cx = numeric::cx_datum();
264 cx.round(&mut a);
265 cx.clear_status();
266 cx.try_into_u32(a)
267 .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
268}
269
270#[sqlfunc(
271 sqlname = "numeric_to_uint8",
272 preserves_uniqueness = false,
273 inverse = to_unary!(super::CastUint64ToNumeric(None)),
274 is_monotone = true
275)]
276fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
277 let mut cx = numeric::cx_datum();
278 cx.round(&mut a);
279 cx.clear_status();
280 cx.try_into_u64(a)
281 .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
282}
283
284#[sqlfunc(sqlname = "pg_size_pretty", preserves_uniqueness = false)]
285fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
286 let mut cx = numeric::cx_datum();
287 let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
288
289 for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
290 if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
292 if pos > 0 {
294 cx.round(&mut a);
295 }
296
297 return Ok(format!("{} {unit}", a.to_standard_notation_string()));
298 }
299
300 cx.div(&mut a, &Numeric::from(1024));
301 numeric::munge_numeric(&mut a).unwrap();
302 }
303
304 cx.round(&mut a);
305 Ok(format!(
306 "{} {}",
307 a.to_standard_notation_string(),
308 units.last().unwrap()
309 ))
310}
311
312#[derive(
313 Ord,
314 PartialOrd,
315 Clone,
316 Debug,
317 Eq,
318 PartialEq,
319 Serialize,
320 Deserialize,
321 Hash,
322 MzReflect
323)]
324pub struct AdjustNumericScale(pub NumericMaxScale);
325
326impl EagerUnaryFunc for AdjustNumericScale {
327 type Input<'a> = Numeric;
328 type Output<'a> = Result<Numeric, EvalError>;
329
330 fn call<'a>(&self, mut d: Self::Input<'a>) -> Self::Output<'a> {
331 if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
332 return Err(EvalError::NumericFieldOverflow);
333 };
334 Ok(d)
335 }
336
337 fn output_sql_type(&self, input: SqlColumnType) -> SqlColumnType {
338 SqlScalarType::Numeric {
339 max_scale: Some(self.0),
340 }
341 .nullable(input.nullable)
342 }
343}
344
345impl fmt::Display for AdjustNumericScale {
346 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
347 f.write_str("adjust_numeric_scale")
348 }
349}