Skip to main content

mz_pgrepr/value/
numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12use std::sync::LazyLock;
13
14use byteorder::{NetworkEndian, ReadBytesExt};
15use bytes::{BufMut, BytesMut};
16use dec::OrderedDecimal;
17use mz_ore::cast::CastFrom;
18use mz_repr::adt::numeric::{self, Numeric as AdtNumeric, NumericAgg, cx_datum};
19use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
20
21/// A wrapper for the `repr` crate's `Decimal` type that can be serialized to
22/// and deserialized from the PostgreSQL binary format.
23#[derive(Debug)]
24pub struct Numeric(pub OrderedDecimal<AdtNumeric>);
25
26impl fmt::Display for Numeric {
27    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28        self.0.fmt(f)
29    }
30}
31
32impl From<AdtNumeric> for Numeric {
33    fn from(n: AdtNumeric) -> Numeric {
34        Numeric(OrderedDecimal(n))
35    }
36}
37
38/// `ToSql` and `FromSql` use base 10,000 units.
39const TO_FROM_SQL_BASE_POW: u32 = 4;
40
41static TO_SQL_BASER: LazyLock<AdtNumeric> =
42    LazyLock::new(|| AdtNumeric::from(10u32.pow(TO_FROM_SQL_BASE_POW)));
43static FROM_SQL_SCALER: LazyLock<AdtNumeric> =
44    LazyLock::new(|| AdtNumeric::from(TO_FROM_SQL_BASE_POW));
45
46/// The maximum number of units necessary to represent a valid [`AdtNumeric`] value.
47const UNITS_LEN: usize = 11;
48
impl ToSql for Numeric {
    /// Encodes `self` in the PostgreSQL binary `numeric` format.
    ///
    /// Writes a header of four network-endian 16-bit fields:
    /// 1. the number of base-10,000 units;
    /// 2. the weight, i.e. the base-10,000 exponent of the first unit;
    /// 3. the sign/special-value marker;
    /// 4. the display scale.
    ///
    /// The units themselves follow, most significant first. NaN carries no
    /// units.
    ///
    /// # Errors
    ///
    /// Returns an error if the scale reported by `numeric::get_scale` does not
    /// fit in a `u16`.
    ///
    /// # Panics
    ///
    /// Several `unwrap`/`expect` calls (and the `d_i` index arithmetic)
    /// encode assumptions about how large a valid [`AdtNumeric`] can be, e.g.
    /// that at most `UNITS_LEN` units are ever produced.
    fn to_sql(
        &self,
        _: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn Error + 'static + Send + Sync>> {
        // Work on a copy: `d` is made absolute, rescaled, and consumed below.
        let mut d = self.0.0.clone();
        let scale = u16::try_from(numeric::get_scale(&d))?;
        // Classify the value up front, before `d` is mutated.
        let is_zero = d.is_zero();
        let is_nan = d.is_nan();
        // Negative zero is encoded with a positive sign.
        let is_neg = d.is_negative() && !is_zero;
        let is_infinite = d.is_infinite();

        let mut cx = numeric::cx_datum();
        // Need to extend exponents slightly because fractional components need
        // to be aligned to base 10,000.
        cx.set_max_exponent(cx.max_exponent() + isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
            .unwrap();
        cx.set_min_exponent(cx.min_exponent() - isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
            .unwrap();
        // The sign is emitted separately, so units are extracted from |d|.
        cx.abs(&mut d);

        // Units are written from the back of `digits` towards the front;
        // `d_i` is the index of the most significant unit written so far.
        let mut digits = [0u16; UNITS_LEN];
        let mut d_i = UNITS_LEN;

        // `fract_units`: how many trailing units lie after the decimal point.
        // `leading_zero_units`: zero units to prepend ahead of the extracted
        // units (they are already zero in `digits`).
        let (fract_units, leading_zero_units) = if d.exponent() < 0 {
            let pos_exp = usize::try_from(-d.exponent()).expect("positive value < 40");
            // You have leading zeroes in the case where:
            // - The exponent's absolute value exceeds the number of digits
            // - `d` only contains fractional zeroes
            let leading_zero_units = if pos_exp >= usize::cast_from(d.digits()) {
                // If the value is zero, one zero digit gets double counted
                // (this is also why the above inequality is not strict)
                let digits = if d.is_zero() {
                    0
                } else {
                    usize::cast_from(d.digits())
                };
                // integer division with rounding up instead of down
                // NOTE(review): rounding up can pad one more zero unit than
                // the leading fractional zeros strictly need. For example,
                // 0.01 is sent as units [0, 100] with weight 0 rather than
                // [100] with weight -1. The value is unchanged, because
                // `weight` below is computed from the padded unit count.
                (pos_exp - digits + usize::cast_from(TO_FROM_SQL_BASE_POW) - 1)
                    / usize::cast_from(TO_FROM_SQL_BASE_POW)
            } else {
                0
            };

            // Round `pos_exp` up to a multiple of TO_FROM_SQL_BASE_POW so
            // that, after scaling, the decimal point falls exactly on a
            // base-10,000 unit boundary.
            let s = pos_exp % usize::cast_from(TO_FROM_SQL_BASE_POW);
            let unit_shift_exp = if s != 0 {
                pos_exp + usize::cast_from(TO_FROM_SQL_BASE_POW) - s
            } else {
                pos_exp
            };

            // Turn d into an integer whose lowest `unit_shift_exp / 4` units
            // are the (unit-aligned) fractional part.
            cx.scaleb(&mut d, &AdtNumeric::from(unit_shift_exp));

            (
                u16::try_from(unit_shift_exp / usize::cast_from(TO_FROM_SQL_BASE_POW))
                    .expect("value < 40"),
                leading_zero_units,
            )
        } else {
            (0, 0)
        };

        // Peel off base-10,000 units, least significant first. `w` carries
        // the running quotient, while `d` is reduced to each remainder.
        // NaN and infinities are special values and contribute no units.
        let mut w = d.clone();
        while !d.is_zero() && !d.is_special() {
            d_i -= 1;
            // Get unit value, i.e. d % 10,000
            cx.rem(&mut d, &TO_SQL_BASER);
            // Decimal library doesn't support direct u16 conversion.
            digits[d_i] =
                u16::try_from(u32::try_from(d).expect("value < 10,000")).expect("value < 10,000");
            cx.div_integer(&mut w, &TO_SQL_BASER);
            d = w;
        }
        // Include the zero padding units in front of the extracted ones.
        d_i -= leading_zero_units;

        // Units to emit: extracted units plus zero padding.
        let units = u16::try_from(UNITS_LEN - d_i).unwrap();
        // Base-10,000 exponent of the first emitted unit; zero is always sent
        // with weight 0.
        let weight = if is_zero {
            0
        } else {
            i16::try_from(units - fract_units).unwrap() - 1
        };

        out.put_u16(units);
        out.put_i16(weight);
        // sign: 0xF000 = -Infinity, 0xD000 = +Infinity, 0xC000 = NaN,
        // 0x4000 = negative, 0 = positive (matching `from_sql`'s decoding).
        out.put_u16(if is_infinite {
            if is_neg { 0xF000 } else { 0xD000 }
        } else if is_nan {
            0xC000
        } else if is_neg {
            0x4000
        } else {
            0
        });
        out.put_u16(scale);
        if !is_nan {
            for digit in digits[d_i..].iter() {
                out.put_u16(*digit);
            }
        }

        Ok(IsNull::No)
    }

    /// Only the PostgreSQL `numeric` type is supported.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }

    to_sql_checked!();
}
164
165impl<'a> FromSql<'a> for Numeric {
166    fn from_sql(_: &Type, mut raw: &'a [u8]) -> Result<Numeric, Box<dyn Error + Sync + Send>> {
167        let units = raw.read_i16::<NetworkEndian>()?;
168        let weight = raw.read_i16::<NetworkEndian>()?;
169        let sign = raw.read_u16::<NetworkEndian>()?;
170        let in_scale = raw.read_i16::<NetworkEndian>()?;
171        let mut digits = vec![];
172        for _ in 0..units {
173            digits.push(raw.read_u16::<NetworkEndian>()?)
174        }
175
176        // We need wider context because decoding values can require >39 digits
177        // of precision given how alignment works.
178        let mut cx = numeric::cx_agg();
179        let mut d = NumericAgg::zero();
180
181        let units_usize =
182            usize::try_from(units).map_err(|_| "units must not be negative: {units}")?;
183
184        for digit in digits[..units_usize].iter() {
185            cx.scaleb(&mut d, &FROM_SQL_SCALER);
186            let n = AdtNumeric::from(u32::from(*digit));
187            cx.add(&mut d, &n);
188        }
189
190        match sign {
191            0 => (),
192            // Infinity
193            0xD000 => return Ok(Numeric::from(AdtNumeric::infinity())),
194            // -Infinity
195            0xF000 => {
196                let mut cx = numeric::cx_datum();
197                let mut d = AdtNumeric::infinity();
198                cx.neg(&mut d);
199                return Ok(Numeric::from(d));
200            }
201            // Negative
202            0x4000 => cx.neg(&mut d),
203            // NaN
204            0xC000 => return Ok(Numeric::from(AdtNumeric::nan())),
205            _ => return Err("bad sign in numeric".into()),
206        }
207
208        // `units`, `weight`, and `in_scale` are read verbatim off the wire, so
209        // for a binary `Bind` parameter they are entirely user-controlled.
210        // Validate them and compute the scale in `i32`: done in `i16` (as this
211        // once was), a `weight` of `i16::MIN` makes `units - weight - 1` exceed
212        // `i16::MAX` and overflow, panicking the connection task.
213        if in_scale < 0 {
214            return Err(format!("invalid numeric binary value: negative dscale {in_scale}").into());
215        }
216        let base_pow = i32::try_from(TO_FROM_SQL_BASE_POW).expect("TO_FROM_SQL_BASE_POW fits i32");
217        let mut scale = (i32::from(units) - i32::from(weight) - 1) * base_pow;
218
219        // Reject headers whose implied scale is out of range rather than letting
220        // the decimal arithmetic below silently collapse the value to zero (or
221        // trip a downstream context-status error). A valid `numeric` decodes
222        // within the aggregation context's precision, so anything beyond that
223        // (e.g. the `weight = i16::MIN` value above, whose scale is ~131072) is
224        // not representable.
225        if scale.unsigned_abs() > u32::from(numeric::NUMERIC_AGG_MAX_PRECISION) {
226            return Err(format!("invalid numeric binary value: scale {scale} out of range").into());
227        }
228
229        // Adjust scales
230        if scale < 0 {
231            // Multiply by 10^scale
232            cx.scaleb(&mut d, &AdtNumeric::from(-scale));
233            scale = 0;
234        } else if scale > i32::from(in_scale) {
235            // Divide by 10^(difference in scale and in_scale)
236            cx.scaleb(&mut d, &AdtNumeric::from(-(scale - i32::from(in_scale))));
237            scale = i32::from(in_scale);
238        }
239
240        cx.scaleb(&mut d, &AdtNumeric::from(-scale));
241        cx.reduce(&mut d);
242
243        let mut cx = cx_datum();
244        let d_datum = cx.to_width(d);
245
246        // Reducing before taking to datum width lets us check for any status
247        // for errors.
248        if d.is_infinite() || cx.status().any() {
249            return Err(format!("Unable to take bytes to numeric value; rendered {}", d).into());
250        }
251        Ok(Numeric::from(d_datum))
252    }
253
254    fn accepts(ty: &Type) -> bool {
255        matches!(*ty, Type::NUMERIC)
256    }
257}
258
259#[mz_ore::test]
260#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decContextDefault` on OS `linux`
261fn test_to_from_sql_roundtrip() {
262    fn inner(s: &str) {
263        let mut cx = numeric::cx_datum();
264        let d = cx.parse(s).unwrap();
265        let r = Numeric(OrderedDecimal(d));
266
267        let mut out = BytesMut::new();
268
269        let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
270
271        let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
272        assert_eq!(r.0, d_from_sql.0);
273    }
274    inner("0");
275    inner("-0");
276    inner("0.1");
277    inner("0.0");
278    inner("0.00");
279    inner("0.000");
280    inner("0.0000");
281    inner("0.00000");
282    inner("123456789.012346789");
283    inner("000000000000000000000000000000000000001");
284    inner("000000000000000000000000000000000000000");
285    inner("999999999999999999999999999999999999999");
286    inner("123456789012345678901234567890123456789");
287    inner("-123456789012345678901234567890123456789");
288    inner(".123456789012345678901234567890123456789");
289    inner(".000000000000000000000000000000000000001");
290    inner(".000000000000000000000000000000000000000");
291    inner(".999999999999999999999999999999999999999");
292    inner("-0.123456789012345678901234567890123456789");
293    inner("1e25");
294    inner("-1e25");
295    inner("9.876e-25");
296    inner("-9.876e-25");
297    inner("98760000");
298    inner(".00009876");
299    inner("-.00009876");
300    inner("NaN");
301
302    // Test infinity, which is a valid value in aggregations over numeric
303    let mut cx = numeric::cx_datum();
304    let v = [
305        cx.parse("-999999999999999999999999999999999999999")
306            .unwrap(),
307        cx.parse("-999999999999999999999999999999999999999")
308            .unwrap(),
309    ];
310    // -Infinity
311    let s = cx.sum(v.iter());
312    assert!(s.is_infinite());
313    let r = Numeric::from(s);
314    let mut out = BytesMut::new();
315
316    let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
317
318    let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
319    assert_eq!(r.0, d_from_sql.0);
320}