// mz_pgrepr/value/numeric.rs
1use std::error::Error;
11use std::fmt;
12use std::sync::LazyLock;
13
14use byteorder::{NetworkEndian, ReadBytesExt};
15use bytes::{BufMut, BytesMut};
16use dec::OrderedDecimal;
17use mz_ore::cast::CastFrom;
18use mz_repr::adt::numeric::{self, Numeric as AdtNumeric, NumericAgg, cx_datum};
19use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
20
/// A PostgreSQL `NUMERIC` value.
///
/// Wraps an `mz_repr` numeric (inside `OrderedDecimal`) so that it can be
/// encoded to and decoded from PostgreSQL's binary `NUMERIC` wire format via
/// the `ToSql` / `FromSql` impls in this module.
#[derive(Debug)]
pub struct Numeric(pub OrderedDecimal<AdtNumeric>);
25
26impl fmt::Display for Numeric {
27 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28 self.0.fmt(f)
29 }
30}
31
32impl From<AdtNumeric> for Numeric {
33 fn from(n: AdtNumeric) -> Numeric {
34 Numeric(OrderedDecimal(n))
35 }
36}
37
/// Number of decimal digits packed into each PostgreSQL `NUMERIC` wire digit;
/// the wire digits are base 10^4 = 10,000.
const TO_FROM_SQL_BASE_POW: u32 = 4;

/// 10^TO_FROM_SQL_BASE_POW (10,000): the modulus/divisor `to_sql` uses to peel
/// off one base-10,000 digit at a time.
static TO_SQL_BASER: LazyLock<AdtNumeric> =
    LazyLock::new(|| AdtNumeric::from(10u32.pow(TO_FROM_SQL_BASE_POW)));
/// TO_FROM_SQL_BASE_POW as a decimal: the `scaleb` argument `from_sql` uses to
/// shift its accumulator left by one base-10,000 digit (multiply by 10,000).
static FROM_SQL_SCALER: LazyLock<AdtNumeric> =
    LazyLock::new(|| AdtNumeric::from(TO_FROM_SQL_BASE_POW));

/// Capacity, in base-10,000 digits, of the digit buffer used by `to_sql`.
/// NOTE(review): presumably sized for the 39-digit datum precision plus up to
/// TO_FROM_SQL_BASE_POW - 1 alignment zeros (42 decimal digits -> 11 units) —
/// confirm against `mz_repr::adt::numeric`.
const UNITS_LEN: usize = 11;
48
impl ToSql for Numeric {
    /// Encodes this value in PostgreSQL's binary `NUMERIC` wire format.
    ///
    /// Writes four big-endian 16-bit header fields to `out`: the number of
    /// base-10,000 digits (`units`), the weight of the first digit, the sign
    /// word, and the display scale. These are followed by the base-10,000
    /// digits, most significant first. No digits are written for NaN.
    ///
    /// # Errors
    ///
    /// Returns an error if the value's scale does not fit in a `u16`.
    ///
    /// # Panics
    ///
    /// Panics via `expect`/`unwrap` if an internal digit-count invariant is
    /// violated.
    /// NOTE(review): presumably unreachable for values within datum precision
    /// — confirm.
    fn to_sql(
        &self,
        _: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn Error + 'static + Send + Sync>> {
        let mut d = self.0.0.clone();
        let scale = u16::try_from(numeric::get_scale(&d))?;
        let is_zero = d.is_zero();
        let is_nan = d.is_nan();
        // Negative zero is sent as plain (positive) zero.
        let is_neg = d.is_negative() && !is_zero;
        let is_infinite = d.is_infinite();

        // Widen the datum context's exponent range by TO_FROM_SQL_BASE_POW.
        // After the alignment `scaleb` below, the value can have up to
        // TO_FROM_SQL_BASE_POW - 1 extra trailing digits, which could otherwise
        // push it out of range.
        let mut cx = numeric::cx_datum();
        cx.set_max_exponent(cx.max_exponent() + isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
            .unwrap();
        cx.set_min_exponent(cx.min_exponent() - isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
            .unwrap();
        // The sign travels in the sign word, so only the magnitude is encoded.
        cx.abs(&mut d);

        // Base-10,000 digits are filled from the end of the buffer backwards;
        // `d_i` is the index of the most significant digit written so far.
        let mut digits = [0u16; UNITS_LEN];
        let mut d_i = UNITS_LEN;

        // `fract_units`: number of base-10,000 digits after the decimal point.
        // `leading_zero_units`: all-zero fractional digits before the first
        // significant one. The extraction loop below never produces these.
        let (fract_units, leading_zero_units) = if d.exponent() < 0 {
            // Number of decimal digits after the decimal point.
            let pos_exp = usize::try_from(-d.exponent()).expect("positive value < 40");
            // Leading fractional zeros can exist only when there are at least as
            // many fractional digits as coefficient digits.
            let leading_zero_units = if pos_exp >= usize::cast_from(d.digits()) {
                // A zero coefficient reports one digit, but it has no
                // significant digits, so count it as zero. That way every
                // fractional unit of a zero value is counted here.
                let digits = if d.is_zero() {
                    0
                } else {
                    usize::cast_from(d.digits())
                };
                // Ceiling division of the leading zero count by
                // TO_FROM_SQL_BASE_POW.
                (pos_exp - digits + usize::cast_from(TO_FROM_SQL_BASE_POW) - 1)
                    / usize::cast_from(TO_FROM_SQL_BASE_POW)
            } else {
                0
            };

            // Round the fractional digit count up to a whole number of
            // base-10,000 digits, so the fraction ends exactly on a unit boundary.
            let s = pos_exp % usize::cast_from(TO_FROM_SQL_BASE_POW);
            let unit_shift_exp = if s != 0 {
                pos_exp + usize::cast_from(TO_FROM_SQL_BASE_POW) - s
            } else {
                pos_exp
            };

            // Multiply by 10^unit_shift_exp. Because unit_shift_exp >= pos_exp,
            // `d` becomes an integer whose low fract_units base-10,000 digits
            // are the fraction.
            cx.scaleb(&mut d, &AdtNumeric::from(unit_shift_exp));

            (
                u16::try_from(unit_shift_exp / usize::cast_from(TO_FROM_SQL_BASE_POW))
                    .expect("value < 40"),
                leading_zero_units,
            )
        } else {
            (0, 0)
        };

        // Peel off base-10,000 digits from least to most significant.
        // NaN and infinities are special and produce no digits.
        let mut w = d.clone();
        while !d.is_zero() && !d.is_special() {
            d_i -= 1;
            // d = d mod 10,000 (the lowest remaining base-10,000 digit).
            cx.rem(&mut d, &TO_SQL_BASER);
            // The decimal type has no direct u16 conversion; go through u32.
            digits[d_i] =
                u16::try_from(u32::try_from(d).expect("value < 10,000")).expect("value < 10,000");
            // w = w div 10,000; continue with the remaining higher digits.
            cx.div_integer(&mut w, &TO_SQL_BASER);
            d = w;
        }
        // Include the leading fractional zero units. `digits` is zero-initialized,
        // so moving the start index back is enough.
        d_i -= leading_zero_units;

        let units = u16::try_from(UNITS_LEN - d_i).unwrap();
        // Weight of the first digit: the number of integer-part units minus
        // one. Zero is sent with weight 0.
        let weight = if is_zero {
            0
        } else {
            i16::try_from(units - fract_units).unwrap() - 1
        };

        out.put_u16(units);
        out.put_i16(weight);
        // Sign word: 0xF000 = -Infinity, 0xD000 = +Infinity, 0xC000 = NaN,
        // 0x4000 = negative, 0 = positive.
        out.put_u16(if is_infinite {
            if is_neg { 0xF000 } else { 0xD000 }
        } else if is_nan {
            0xC000
        } else if is_neg {
            0x4000
        } else {
            0
        });
        out.put_u16(scale);
        if !is_nan {
            for digit in digits[d_i..].iter() {
                out.put_u16(*digit);
            }
        }

        Ok(IsNull::No)
    }

    /// Only the PostgreSQL `NUMERIC` type is supported.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }

    to_sql_checked!();
}
164
165impl<'a> FromSql<'a> for Numeric {
166 fn from_sql(_: &Type, mut raw: &'a [u8]) -> Result<Numeric, Box<dyn Error + Sync + Send>> {
167 let units = raw.read_i16::<NetworkEndian>()?;
168 let weight = raw.read_i16::<NetworkEndian>()?;
169 let sign = raw.read_u16::<NetworkEndian>()?;
170 let in_scale = raw.read_i16::<NetworkEndian>()?;
171 let mut digits = vec![];
172 for _ in 0..units {
173 digits.push(raw.read_u16::<NetworkEndian>()?)
174 }
175
176 let mut cx = numeric::cx_agg();
179 let mut d = NumericAgg::zero();
180
181 let units_usize =
182 usize::try_from(units).map_err(|_| "units must not be negative: {units}")?;
183
184 for digit in digits[..units_usize].iter() {
185 cx.scaleb(&mut d, &FROM_SQL_SCALER);
186 let n = AdtNumeric::from(u32::from(*digit));
187 cx.add(&mut d, &n);
188 }
189
190 match sign {
191 0 => (),
192 0xD000 => return Ok(Numeric::from(AdtNumeric::infinity())),
194 0xF000 => {
196 let mut cx = numeric::cx_datum();
197 let mut d = AdtNumeric::infinity();
198 cx.neg(&mut d);
199 return Ok(Numeric::from(d));
200 }
201 0x4000 => cx.neg(&mut d),
203 0xC000 => return Ok(Numeric::from(AdtNumeric::nan())),
205 _ => return Err("bad sign in numeric".into()),
206 }
207
208 let mut scale = (units - weight - 1) * 4;
209
210 if scale < 0 {
212 cx.scaleb(&mut d, &AdtNumeric::from(-i32::from(scale)));
214 scale = 0;
215 } else if scale > in_scale {
216 cx.scaleb(&mut d, &AdtNumeric::from(-i32::from(scale - in_scale)));
218 scale = in_scale;
219 }
220
221 cx.scaleb(&mut d, &AdtNumeric::from(-i32::from(scale)));
222 cx.reduce(&mut d);
223
224 let mut cx = cx_datum();
225 let d_datum = cx.to_width(d);
226
227 if d.is_infinite() || cx.status().any() {
230 return Err(format!("Unable to take bytes to numeric value; rendered {}", d).into());
231 }
232 Ok(Numeric::from(d_datum))
233 }
234
235 fn accepts(ty: &Type) -> bool {
236 matches!(*ty, Type::NUMERIC)
237 }
238}
239
240#[mz_ore::test]
241#[cfg_attr(miri, ignore)] fn test_to_from_sql_roundtrip() {
243 fn inner(s: &str) {
244 let mut cx = numeric::cx_datum();
245 let d = cx.parse(s).unwrap();
246 let r = Numeric(OrderedDecimal(d));
247
248 let mut out = BytesMut::new();
249
250 let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
251
252 let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
253 assert_eq!(r.0, d_from_sql.0);
254 }
255 inner("0");
256 inner("-0");
257 inner("0.1");
258 inner("0.0");
259 inner("0.00");
260 inner("0.000");
261 inner("0.0000");
262 inner("0.00000");
263 inner("123456789.012346789");
264 inner("000000000000000000000000000000000000001");
265 inner("000000000000000000000000000000000000000");
266 inner("999999999999999999999999999999999999999");
267 inner("123456789012345678901234567890123456789");
268 inner("-123456789012345678901234567890123456789");
269 inner(".123456789012345678901234567890123456789");
270 inner(".000000000000000000000000000000000000001");
271 inner(".000000000000000000000000000000000000000");
272 inner(".999999999999999999999999999999999999999");
273 inner("-0.123456789012345678901234567890123456789");
274 inner("1e25");
275 inner("-1e25");
276 inner("9.876e-25");
277 inner("-9.876e-25");
278 inner("98760000");
279 inner(".00009876");
280 inner("-.00009876");
281 inner("NaN");
282
283 let mut cx = numeric::cx_datum();
285 let v = [
286 cx.parse("-999999999999999999999999999999999999999")
287 .unwrap(),
288 cx.parse("-999999999999999999999999999999999999999")
289 .unwrap(),
290 ];
291 let s = cx.sum(v.iter());
293 assert!(s.is_infinite());
294 let r = Numeric::from(s);
295 let mut out = BytesMut::new();
296
297 let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
298
299 let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
300 assert_eq!(r.0, d_from_sql.0);
301}