mz_pgrepr/value/
numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::error::Error;
11use std::fmt;
12use std::sync::LazyLock;
13
14use byteorder::{NetworkEndian, ReadBytesExt};
15use bytes::{BufMut, BytesMut};
16use dec::OrderedDecimal;
17use mz_ore::cast::CastFrom;
18use mz_repr::adt::numeric::{self, Numeric as AdtNumeric, NumericAgg, cx_datum};
19use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
20
21/// A wrapper for the `repr` crate's `Decimal` type that can be serialized to
22/// and deserialized from the PostgreSQL binary format.
23#[derive(Debug)]
24pub struct Numeric(pub OrderedDecimal<AdtNumeric>);
25
26impl fmt::Display for Numeric {
27    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28        self.0.fmt(f)
29    }
30}
31
32impl From<AdtNumeric> for Numeric {
33    fn from(n: AdtNumeric) -> Numeric {
34        Numeric(OrderedDecimal(n))
35    }
36}
37
38/// `ToSql` and `FromSql` use base 10,000 units.
39const TO_FROM_SQL_BASE_POW: u32 = 4;
40
41static TO_SQL_BASER: LazyLock<AdtNumeric> =
42    LazyLock::new(|| AdtNumeric::from(10u32.pow(TO_FROM_SQL_BASE_POW)));
43static FROM_SQL_SCALER: LazyLock<AdtNumeric> =
44    LazyLock::new(|| AdtNumeric::from(TO_FROM_SQL_BASE_POW));
45
46/// The maximum number of units necessary to represent a valid [`AdtNumeric`] value.
47const UNITS_LEN: usize = 11;
48
49impl ToSql for Numeric {
50    fn to_sql(
51        &self,
52        _: &Type,
53        out: &mut BytesMut,
54    ) -> Result<IsNull, Box<dyn Error + 'static + Send + Sync>> {
55        let mut d = self.0.0.clone();
56        let scale = u16::try_from(numeric::get_scale(&d))?;
57        let is_zero = d.is_zero();
58        let is_nan = d.is_nan();
59        let is_neg = d.is_negative() && !is_zero;
60        let is_infinite = d.is_infinite();
61
62        let mut cx = numeric::cx_datum();
63        // Need to extend exponents slightly because fractional components need
64        // to be aligned to base 10,000.
65        cx.set_max_exponent(cx.max_exponent() + isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
66            .unwrap();
67        cx.set_min_exponent(cx.min_exponent() - isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
68            .unwrap();
69        cx.abs(&mut d);
70
71        let mut digits = [0u16; UNITS_LEN];
72        let mut d_i = UNITS_LEN;
73
74        let (fract_units, leading_zero_units) = if d.exponent() < 0 {
75            let pos_exp = usize::try_from(-d.exponent()).expect("positive value < 40");
76            // You have leading zeroes in the case where:
77            // - The exponent's absolute value exceeds the number of digits
78            // - `d` only contains fractional zeroes
79            let leading_zero_units = if pos_exp >= usize::cast_from(d.digits()) {
80                // If the value is zero, one zero digit gets double counted
81                // (this is also why the above inequality is not strict)
82                let digits = if d.is_zero() {
83                    0
84                } else {
85                    usize::cast_from(d.digits())
86                };
87                // integer division with rounding up instead of down
88                (pos_exp - digits + usize::cast_from(TO_FROM_SQL_BASE_POW) - 1)
89                    / usize::cast_from(TO_FROM_SQL_BASE_POW)
90            } else {
91                0
92            };
93
94            // Ensure most significant fractional digit in ten's place of base
95            // 10,000 value.
96            let s = pos_exp % usize::cast_from(TO_FROM_SQL_BASE_POW);
97            let unit_shift_exp = if s != 0 {
98                pos_exp + usize::cast_from(TO_FROM_SQL_BASE_POW) - s
99            } else {
100                pos_exp
101            };
102
103            // Convert d into a "canonical coefficient" with most significant
104            // fractional digit properly aligned.
105            cx.scaleb(&mut d, &AdtNumeric::from(unit_shift_exp));
106
107            (
108                u16::try_from(unit_shift_exp / usize::cast_from(TO_FROM_SQL_BASE_POW))
109                    .expect("value < 40"),
110                leading_zero_units,
111            )
112        } else {
113            (0, 0)
114        };
115
116        let mut w = d.clone();
117        while !d.is_zero() && !d.is_special() {
118            d_i -= 1;
119            // Get unit value, i.e. d % 10,000
120            cx.rem(&mut d, &TO_SQL_BASER);
121            // Decimal library doesn't support direct u16 conversion.
122            digits[d_i] =
123                u16::try_from(u32::try_from(d).expect("value < 10,000")).expect("value < 10,000");
124            cx.div_integer(&mut w, &TO_SQL_BASER);
125            d = w;
126        }
127        d_i -= leading_zero_units;
128
129        let units = u16::try_from(UNITS_LEN - d_i).unwrap();
130        let weight = if is_zero {
131            0
132        } else {
133            i16::try_from(units - fract_units).unwrap() - 1
134        };
135
136        out.put_u16(units);
137        out.put_i16(weight);
138        // sign
139        out.put_u16(if is_infinite {
140            if is_neg { 0xF000 } else { 0xD000 }
141        } else if is_nan {
142            0xC000
143        } else if is_neg {
144            0x4000
145        } else {
146            0
147        });
148        out.put_u16(scale);
149        if !is_nan {
150            for digit in digits[d_i..].iter() {
151                out.put_u16(*digit);
152            }
153        }
154
155        Ok(IsNull::No)
156    }
157
158    fn accepts(ty: &Type) -> bool {
159        matches!(*ty, Type::NUMERIC)
160    }
161
162    to_sql_checked!();
163}
164
165impl<'a> FromSql<'a> for Numeric {
166    fn from_sql(_: &Type, mut raw: &'a [u8]) -> Result<Numeric, Box<dyn Error + Sync + Send>> {
167        let units = raw.read_i16::<NetworkEndian>()?;
168        let weight = raw.read_i16::<NetworkEndian>()?;
169        let sign = raw.read_u16::<NetworkEndian>()?;
170        let in_scale = raw.read_i16::<NetworkEndian>()?;
171        let mut digits = vec![];
172        for _ in 0..units {
173            digits.push(raw.read_u16::<NetworkEndian>()?)
174        }
175
176        // We need wider context because decoding values can require >39 digits
177        // of precision given how alignment works.
178        let mut cx = numeric::cx_agg();
179        let mut d = NumericAgg::zero();
180
181        let units_usize =
182            usize::try_from(units).map_err(|_| "units must not be negative: {units}")?;
183
184        for digit in digits[..units_usize].iter() {
185            cx.scaleb(&mut d, &FROM_SQL_SCALER);
186            let n = AdtNumeric::from(u32::from(*digit));
187            cx.add(&mut d, &n);
188        }
189
190        match sign {
191            0 => (),
192            // Infinity
193            0xD000 => return Ok(Numeric::from(AdtNumeric::infinity())),
194            // -Infinity
195            0xF000 => {
196                let mut cx = numeric::cx_datum();
197                let mut d = AdtNumeric::infinity();
198                cx.neg(&mut d);
199                return Ok(Numeric::from(d));
200            }
201            // Negative
202            0x4000 => cx.neg(&mut d),
203            // NaN
204            0xC000 => return Ok(Numeric::from(AdtNumeric::nan())),
205            _ => return Err("bad sign in numeric".into()),
206        }
207
208        let mut scale = (units - weight - 1) * 4;
209
210        // Adjust scales
211        if scale < 0 {
212            // Multiply by 10^scale
213            cx.scaleb(&mut d, &AdtNumeric::from(-i32::from(scale)));
214            scale = 0;
215        } else if scale > in_scale {
216            // Divide by 10^(difference in scale and in_scale)
217            cx.scaleb(&mut d, &AdtNumeric::from(-i32::from(scale - in_scale)));
218            scale = in_scale;
219        }
220
221        cx.scaleb(&mut d, &AdtNumeric::from(-i32::from(scale)));
222        cx.reduce(&mut d);
223
224        let mut cx = cx_datum();
225        let d_datum = cx.to_width(d);
226
227        // Reducing before taking to datum width lets us check for any status
228        // for errors.
229        if d.is_infinite() || cx.status().any() {
230            return Err(format!("Unable to take bytes to numeric value; rendered {}", d).into());
231        }
232        Ok(Numeric::from(d_datum))
233    }
234
235    fn accepts(ty: &Type) -> bool {
236        matches!(*ty, Type::NUMERIC)
237    }
238}
239
240#[mz_ore::test]
241#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decContextDefault` on OS `linux`
242fn test_to_from_sql_roundtrip() {
243    fn inner(s: &str) {
244        let mut cx = numeric::cx_datum();
245        let d = cx.parse(s).unwrap();
246        let r = Numeric(OrderedDecimal(d));
247
248        let mut out = BytesMut::new();
249
250        let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
251
252        let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
253        assert_eq!(r.0, d_from_sql.0);
254    }
255    inner("0");
256    inner("-0");
257    inner("0.1");
258    inner("0.0");
259    inner("0.00");
260    inner("0.000");
261    inner("0.0000");
262    inner("0.00000");
263    inner("123456789.012346789");
264    inner("000000000000000000000000000000000000001");
265    inner("000000000000000000000000000000000000000");
266    inner("999999999999999999999999999999999999999");
267    inner("123456789012345678901234567890123456789");
268    inner("-123456789012345678901234567890123456789");
269    inner(".123456789012345678901234567890123456789");
270    inner(".000000000000000000000000000000000000001");
271    inner(".000000000000000000000000000000000000000");
272    inner(".999999999999999999999999999999999999999");
273    inner("-0.123456789012345678901234567890123456789");
274    inner("1e25");
275    inner("-1e25");
276    inner("9.876e-25");
277    inner("-9.876e-25");
278    inner("98760000");
279    inner(".00009876");
280    inner("-.00009876");
281    inner("NaN");
282
283    // Test infinity, which is a valid value in aggregations over numeric
284    let mut cx = numeric::cx_datum();
285    let v = [
286        cx.parse("-999999999999999999999999999999999999999")
287            .unwrap(),
288        cx.parse("-999999999999999999999999999999999999999")
289            .unwrap(),
290    ];
291    // -Infinity
292    let s = cx.sum(v.iter());
293    assert!(s.is_infinite());
294    let r = Numeric::from(s);
295    let mut out = BytesMut::new();
296
297    let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
298
299    let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
300    assert_eq!(r.0, d_from_sql.0);
301}