mz_expr/scalar/func/impls/
numeric.rs
1use std::fmt;
11
12use dec::{OrderedDecimal, Rounding};
13use mz_lowertest::MzReflect;
14use mz_repr::adt::numeric::{self, Numeric, NumericMaxScale};
15use mz_repr::{ColumnType, ScalarType, strconv};
16use proptest_derive::Arbitrary;
17use serde::{Deserialize, Serialize};
18
19use crate::EvalError;
20use crate::scalar::func::EagerUnaryFunc;
21
22sqlfunc!(
23 #[sqlname = "-"]
24 #[preserves_uniqueness = true]
25 #[inverse = to_unary!(NegNumeric)]
26 #[is_monotone = true]
27 fn neg_numeric(mut a: Numeric) -> Numeric {
28 numeric::cx_datum().neg(&mut a);
29 numeric::munge_numeric(&mut a).unwrap();
30 a
31 }
32);
33
34sqlfunc!(
35 #[sqlname = "abs"]
36 fn abs_numeric(mut a: Numeric) -> Numeric {
37 numeric::cx_datum().abs(&mut a);
38 a
39 }
40);
41
42sqlfunc!(
43 #[sqlname = "ceilnumeric"]
44 #[is_monotone = true]
45 fn ceil_numeric(mut a: Numeric) -> Numeric {
46 if a.exponent() >= 0 {
48 return a;
49 }
50 let mut cx = numeric::cx_datum();
51 cx.set_rounding(Rounding::Ceiling);
52 cx.round(&mut a);
53 numeric::munge_numeric(&mut a).unwrap();
54 a
55 }
56);
57
58sqlfunc!(
59 #[sqlname = "expnumeric"]
60 fn exp_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
61 let mut cx = numeric::cx_datum();
62 cx.exp(&mut a);
63 let cx_status = cx.status();
64 if cx_status.overflow() {
65 Err(EvalError::FloatOverflow)
66 } else if cx_status.subnormal() {
67 Err(EvalError::FloatUnderflow)
68 } else {
69 numeric::munge_numeric(&mut a).unwrap();
70 Ok(a)
71 }
72 }
73);
74
75sqlfunc!(
76 #[sqlname = "floornumeric"]
77 #[is_monotone = true]
78 fn floor_numeric(mut a: Numeric) -> Numeric {
79 if a.exponent() >= 0 {
81 return a;
82 }
83 let mut cx = numeric::cx_datum();
84 cx.set_rounding(Rounding::Floor);
85 cx.round(&mut a);
86 numeric::munge_numeric(&mut a).unwrap();
87 a
88 }
89);
90
91fn log_guard_numeric(val: &Numeric, function_name: &str) -> Result<(), EvalError> {
92 if val.is_negative() {
93 return Err(EvalError::NegativeOutOfDomain(function_name.into()));
94 }
95 if val.is_zero() {
96 return Err(EvalError::ZeroOutOfDomain(function_name.into()));
97 }
98 Ok(())
99}
100
101fn log_numeric<F>(mut a: Numeric, logic: F, name: &'static str) -> Result<Numeric, EvalError>
107where
108 F: Fn(&mut dec::Context<Numeric>, &mut Numeric),
109{
110 log_guard_numeric(&a, name)?;
111 let mut cx = numeric::cx_datum();
112 logic(&mut cx, &mut a);
113 numeric::munge_numeric(&mut a).unwrap();
114 Ok(a)
115}
116
117sqlfunc!(
118 #[sqlname = "lnnumeric"]
119 fn ln_numeric(a: Numeric) -> Result<Numeric, EvalError> {
120 log_numeric(a, dec::Context::ln, "ln")
121 }
122);
123
124sqlfunc!(
125 #[sqlname = "log10numeric"]
126 fn log10_numeric(a: Numeric) -> Result<Numeric, EvalError> {
127 log_numeric(a, dec::Context::log10, "log10")
128 }
129);
130
131sqlfunc!(
132 #[sqlname = "roundnumeric"]
133 #[is_monotone = true]
134 fn round_numeric(mut a: Numeric) -> Numeric {
135 if a.exponent() >= 0 {
137 return a;
138 }
139 numeric::cx_datum().round(&mut a);
140 a
141 }
142);
143
144sqlfunc!(
145 #[sqlname = "truncnumeric"]
146 #[is_monotone = true]
147 fn trunc_numeric(mut a: Numeric) -> Numeric {
148 if a.exponent() >= 0 {
150 return a;
151 }
152 let mut cx = numeric::cx_datum();
153 cx.set_rounding(Rounding::Down);
154 cx.round(&mut a);
155 numeric::munge_numeric(&mut a).unwrap();
156 a
157 }
158);
159
160sqlfunc!(
161 #[sqlname = "sqrtnumeric"]
162 fn sqrt_numeric(mut a: Numeric) -> Result<Numeric, EvalError> {
163 if a.is_negative() {
164 return Err(EvalError::NegSqrt);
165 }
166 let mut cx = numeric::cx_datum();
167 cx.sqrt(&mut a);
168 numeric::munge_numeric(&mut a).unwrap();
169 Ok(a)
170 }
171);
172
173sqlfunc!(
174 #[sqlname = "numeric_to_smallint"]
175 #[preserves_uniqueness = false]
176 #[inverse = to_unary!(super::CastInt16ToNumeric(None))]
177 #[is_monotone = true]
178 fn cast_numeric_to_int16(mut a: Numeric) -> Result<i16, EvalError> {
179 let mut cx = numeric::cx_datum();
180 cx.round(&mut a);
181 cx.clear_status();
182 let i = cx
183 .try_into_i32(a)
184 .or_else(|_| Err(EvalError::Int16OutOfRange(a.to_string().into())))?;
185 i16::try_from(i).or_else(|_| Err(EvalError::Int16OutOfRange(i.to_string().into())))
186 }
187);
188
189sqlfunc!(
190 #[sqlname = "numeric_to_integer"]
191 #[preserves_uniqueness = false]
192 #[inverse = to_unary!(super::CastInt32ToNumeric(None))]
193 #[is_monotone = true]
194 fn cast_numeric_to_int32(mut a: Numeric) -> Result<i32, EvalError> {
195 let mut cx = numeric::cx_datum();
196 cx.round(&mut a);
197 cx.clear_status();
198 cx.try_into_i32(a)
199 .or_else(|_| Err(EvalError::Int32OutOfRange(a.to_string().into())))
200 }
201);
202
203sqlfunc!(
204 #[sqlname = "numeric_to_bigint"]
205 #[preserves_uniqueness = false]
206 #[inverse = to_unary!(super::CastInt64ToNumeric(None))]
207 #[is_monotone = true]
208 fn cast_numeric_to_int64(mut a: Numeric) -> Result<i64, EvalError> {
209 let mut cx = numeric::cx_datum();
210 cx.round(&mut a);
211 cx.clear_status();
212 cx.try_into_i64(a)
213 .or_else(|_| Err(EvalError::Int64OutOfRange(a.to_string().into())))
214 }
215);
216
217sqlfunc!(
218 #[sqlname = "numeric_to_real"]
219 #[preserves_uniqueness = false]
220 #[inverse = to_unary!(super::CastFloat32ToNumeric(None))]
221 #[is_monotone = true]
222 fn cast_numeric_to_float32(a: Numeric) -> Result<f32, EvalError> {
223 let i = a.to_string().parse::<f32>().unwrap();
224 if i.is_infinite() {
225 Err(EvalError::Float32OutOfRange(i.to_string().into()))
226 } else {
227 Ok(i)
228 }
229 }
230);
231
232sqlfunc!(
233 #[sqlname = "numeric_to_double"]
234 #[preserves_uniqueness = false]
235 #[inverse = to_unary!(super::CastFloat64ToNumeric(None))]
236 #[is_monotone = true]
237 fn cast_numeric_to_float64(a: Numeric) -> Result<f64, EvalError> {
238 let i = a.to_string().parse::<f64>().unwrap();
239 if i.is_infinite() {
240 Err(EvalError::Float64OutOfRange(i.to_string().into()))
241 } else {
242 Ok(i)
243 }
244 }
245);
246
247sqlfunc!(
248 #[sqlname = "numeric_to_text"]
249 #[preserves_uniqueness = false]
250 #[inverse = to_unary!(super::CastStringToNumeric(None))]
251 fn cast_numeric_to_string(a: Numeric) -> String {
252 let mut buf = String::new();
253 strconv::format_numeric(&mut buf, &OrderedDecimal(a));
254 buf
255 }
256);
257
258sqlfunc!(
259 #[sqlname = "numeric_to_uint2"]
260 #[preserves_uniqueness = false]
261 #[inverse = to_unary!(super::CastUint16ToNumeric(None))]
262 #[is_monotone = true]
263 fn cast_numeric_to_uint16(mut a: Numeric) -> Result<u16, EvalError> {
264 let mut cx = numeric::cx_datum();
265 cx.round(&mut a);
266 cx.clear_status();
267 let u = cx
268 .try_into_u32(a)
269 .or_else(|_| Err(EvalError::UInt16OutOfRange(a.to_string().into())))?;
270 u16::try_from(u).or_else(|_| Err(EvalError::UInt16OutOfRange(u.to_string().into())))
271 }
272);
273
274sqlfunc!(
275 #[sqlname = "numeric_to_uint4"]
276 #[preserves_uniqueness = false]
277 #[inverse = to_unary!(super::CastUint32ToNumeric(None))]
278 #[is_monotone = true]
279 fn cast_numeric_to_uint32(mut a: Numeric) -> Result<u32, EvalError> {
280 let mut cx = numeric::cx_datum();
281 cx.round(&mut a);
282 cx.clear_status();
283 cx.try_into_u32(a)
284 .or_else(|_| Err(EvalError::UInt32OutOfRange(a.to_string().into())))
285 }
286);
287
288sqlfunc!(
289 #[sqlname = "numeric_to_uint8"]
290 #[preserves_uniqueness = false]
291 #[inverse = to_unary!(super::CastUint64ToNumeric(None))]
292 #[is_monotone = true]
293 fn cast_numeric_to_uint64(mut a: Numeric) -> Result<u64, EvalError> {
294 let mut cx = numeric::cx_datum();
295 cx.round(&mut a);
296 cx.clear_status();
297 cx.try_into_u64(a)
298 .or_else(|_| Err(EvalError::UInt64OutOfRange(a.to_string().into())))
299 }
300);
301
302sqlfunc!(
303 #[sqlname = "pg_size_pretty"]
304 #[preserves_uniqueness = false]
305 fn pg_size_pretty(mut a: Numeric) -> Result<String, EvalError> {
306 let mut cx = numeric::cx_datum();
307 let units = ["bytes", "kB", "MB", "GB", "TB", "PB"];
308
309 for (pos, unit) in units.iter().rev().skip(1).rev().enumerate() {
310 if Numeric::from(-10239.5) < a && a < Numeric::from(10239.5) {
312 if pos > 0 {
314 cx.round(&mut a);
315 }
316
317 return Ok(format!("{} {unit}", a.to_standard_notation_string()));
318 }
319
320 cx.div(&mut a, &Numeric::from(1024));
321 numeric::munge_numeric(&mut a).unwrap();
322 }
323
324 cx.round(&mut a);
325 Ok(format!(
326 "{} {}",
327 a.to_standard_notation_string(),
328 units.last().unwrap()
329 ))
330 }
331);
332
/// Unary function that rescales a [`Numeric`] to the maximum scale it carries.
///
/// The wrapped [`NumericMaxScale`] is passed (as a `u8`) to `numeric::rescale`
/// by this type's `EagerUnaryFunc::call`, and is reported as the `max_scale`
/// of the output column type.
#[derive(
    Arbitrary, Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct AdjustNumericScale(pub NumericMaxScale);
337
338impl<'a> EagerUnaryFunc<'a> for AdjustNumericScale {
339 type Input = Numeric;
340 type Output = Result<Numeric, EvalError>;
341
342 fn call(&self, mut d: Numeric) -> Result<Numeric, EvalError> {
343 if numeric::rescale(&mut d, self.0.into_u8()).is_err() {
344 return Err(EvalError::NumericFieldOverflow);
345 };
346 Ok(d)
347 }
348
349 fn output_type(&self, input: ColumnType) -> ColumnType {
350 ScalarType::Numeric {
351 max_scale: Some(self.0),
352 }
353 .nullable(input.nullable)
354 }
355}
356
357impl fmt::Display for AdjustNumericScale {
358 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
359 f.write_str("adjust_numeric_scale")
360 }
361}