Skip to main content

mz_repr/adt/
numeric.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Functions related to Materialize's numeric type, which is largely a wrapper
11//! around [`rust-dec`].
12//!
13//! [`rust-dec`]: https://github.com/MaterializeInc/rust-dec/
14
15use std::error::Error;
16use std::fmt;
17use std::sync::LazyLock;
18
19use anyhow::bail;
20use dec::{Context, Decimal};
21use mz_lowertest::MzReflect;
22use mz_ore::cast;
23use mz_persist_types::columnar::FixedSizeCodec;
24use mz_proto::{ProtoType, RustType, TryFromProtoError};
25#[cfg(any(test, feature = "proptest"))]
26use proptest_derive::Arbitrary;
27use serde::{Deserialize, Serialize};
28
29include!(concat!(env!("OUT_DIR"), "/mz_repr.adt.numeric.rs"));
30
31/// The number of internal decimal units in a [`Numeric`] value.
32pub const NUMERIC_DATUM_WIDTH: u8 = 13;
33
34/// The value of [`NUMERIC_DATUM_WIDTH`] as a [`u8`].
35pub const NUMERIC_DATUM_WIDTH_USIZE: usize = cast::u8_to_usize(NUMERIC_DATUM_WIDTH);
36
37/// The maximum number of digits expressable in a [`Numeric`] value.
38pub const NUMERIC_DATUM_MAX_PRECISION: u8 = NUMERIC_DATUM_WIDTH * 3;
39
40/// A numeric value.
41pub type Numeric = Decimal<NUMERIC_DATUM_WIDTH_USIZE>;
42
43/// The number of internal decimal units in a [`NumericAgg`] value.
44pub const NUMERIC_AGG_WIDTH: u8 = 27;
45
46/// The value of [`NUMERIC_AGG_WIDTH`] as a [`u8`].
47pub const NUMERIC_AGG_WIDTH_USIZE: usize = cast::u8_to_usize(NUMERIC_AGG_WIDTH);
48
49/// The maximum number of digits expressable in a [`NumericAgg`] value.
50pub const NUMERIC_AGG_MAX_PRECISION: u8 = NUMERIC_AGG_WIDTH * 3;
51
52/// A double-width version of [`Numeric`] for use in aggregations.
53pub type NumericAgg = Decimal<NUMERIC_AGG_WIDTH_USIZE>;
54
55static CX_DATUM: LazyLock<Context<Numeric>> = LazyLock::new(|| {
56    let mut cx = Context::<Numeric>::default();
57    cx.set_max_exponent(isize::from(NUMERIC_DATUM_MAX_PRECISION - 1))
58        .unwrap();
59    cx.set_min_exponent(-isize::from(NUMERIC_DATUM_MAX_PRECISION))
60        .unwrap();
61    cx
62});
63static CX_AGG: LazyLock<Context<NumericAgg>> = LazyLock::new(|| {
64    let mut cx = Context::<NumericAgg>::default();
65    cx.set_max_exponent(isize::from(NUMERIC_AGG_MAX_PRECISION - 1))
66        .unwrap();
67    cx.set_min_exponent(-isize::from(NUMERIC_AGG_MAX_PRECISION))
68        .unwrap();
69    cx
70});
71static U128_SPLITTER_DATUM: LazyLock<Numeric> = LazyLock::new(|| {
72    let mut cx = Numeric::context();
73    // 1 << 128
74    cx.parse("340282366920938463463374607431768211456").unwrap()
75});
76static U128_SPLITTER_AGG: LazyLock<NumericAgg> = LazyLock::new(|| {
77    let mut cx = NumericAgg::context();
78    // 1 << 128
79    cx.parse("340282366920938463463374607431768211456").unwrap()
80});
81
82/// Module to simplify serde'ing a `Numeric` through its string representation.
83pub mod str_serde {
84    use std::str::FromStr;
85
86    use serde::Deserialize;
87
88    use super::Numeric;
89
90    /// Deserializing a [`Numeric`] value from its `String` representation.
91    pub fn deserialize<'de, D>(deserializer: D) -> Result<Numeric, D::Error>
92    where
93        D: serde::Deserializer<'de>,
94    {
95        let buf = String::deserialize(deserializer)?;
96        Numeric::from_str(&buf).map_err(serde::de::Error::custom)
97    }
98}
99
100/// The `max_scale` of a [`SqlScalarType::Numeric`].
101///
102/// This newtype wrapper ensures that the scale is within the valid range.
103///
104/// [`SqlScalarType::Numeric`]: crate::SqlScalarType::Numeric
105#[derive(
106    Debug,
107    Clone,
108    Copy,
109    Eq,
110    PartialEq,
111    Ord,
112    PartialOrd,
113    Hash,
114    Serialize,
115    Deserialize,
116    MzReflect
117)]
118#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
119pub struct NumericMaxScale(pub(crate) u8);
120
121impl NumericMaxScale {
122    /// A max scale of zero.
123    pub const ZERO: NumericMaxScale = NumericMaxScale(0);
124
125    /// Consumes the newtype wrapper, returning the inner `u8`.
126    pub fn into_u8(self) -> u8 {
127        self.0
128    }
129}
130
131impl TryFrom<i64> for NumericMaxScale {
132    type Error = InvalidNumericMaxScaleError;
133
134    fn try_from(max_scale: i64) -> Result<Self, Self::Error> {
135        match u8::try_from(max_scale) {
136            Ok(max_scale) if max_scale <= NUMERIC_DATUM_MAX_PRECISION => {
137                Ok(NumericMaxScale(max_scale))
138            }
139            _ => Err(InvalidNumericMaxScaleError),
140        }
141    }
142}
143
144impl TryFrom<usize> for NumericMaxScale {
145    type Error = InvalidNumericMaxScaleError;
146
147    fn try_from(max_scale: usize) -> Result<Self, Self::Error> {
148        Self::try_from(i64::try_from(max_scale).map_err(|_| InvalidNumericMaxScaleError)?)
149    }
150}
151
152impl RustType<ProtoNumericMaxScale> for NumericMaxScale {
153    fn into_proto(&self) -> ProtoNumericMaxScale {
154        ProtoNumericMaxScale {
155            value: self.0.into_proto(),
156        }
157    }
158
159    fn from_proto(max_scale: ProtoNumericMaxScale) -> Result<Self, TryFromProtoError> {
160        Ok(NumericMaxScale(max_scale.value.into_rust()?))
161    }
162}
163
164impl RustType<ProtoOptionalNumericMaxScale> for Option<NumericMaxScale> {
165    fn into_proto(&self) -> ProtoOptionalNumericMaxScale {
166        ProtoOptionalNumericMaxScale {
167            value: self.into_proto(),
168        }
169    }
170
171    fn from_proto(max_scale: ProtoOptionalNumericMaxScale) -> Result<Self, TryFromProtoError> {
172        max_scale.value.into_rust()
173    }
174}
175
176/// The error returned when constructing a [`NumericMaxScale`] from an invalid
177/// value.
178#[derive(Debug, Clone)]
179pub struct InvalidNumericMaxScaleError;
180
181impl fmt::Display for InvalidNumericMaxScaleError {
182    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183        write!(
184            f,
185            "scale for type numeric must be between 0 and {}",
186            NUMERIC_DATUM_MAX_PRECISION
187        )
188    }
189}
190
191impl Error for InvalidNumericMaxScaleError {}
192
193/// Traits to generalize converting [`Decimal`] values to and from their
194/// coefficients' two's complements.
195pub trait Dec<const N: usize> {
196    // The number of bytes required to represent the min/max value of a decimal
197    // using two's complement.
198    const TWOS_COMPLEMENT_BYTE_WIDTH: usize;
199    // Convenience method for generating appropriate default contexts.
200    fn context() -> Context<Decimal<N>>;
201    // Provides value to break decimal into units of `u128`s for binary
202    // encoding/decoding.
203    fn u128_splitter() -> &'static Decimal<N>;
204}
205
206impl Dec<NUMERIC_DATUM_WIDTH_USIZE> for Numeric {
207    const TWOS_COMPLEMENT_BYTE_WIDTH: usize = 17;
208    fn context() -> Context<Numeric> {
209        CX_DATUM.clone()
210    }
211    fn u128_splitter() -> &'static Numeric {
212        &U128_SPLITTER_DATUM
213    }
214}
215
216impl Dec<NUMERIC_AGG_WIDTH_USIZE> for NumericAgg {
217    const TWOS_COMPLEMENT_BYTE_WIDTH: usize = 33;
218    fn context() -> Context<NumericAgg> {
219        CX_AGG.clone()
220    }
221    fn u128_splitter() -> &'static NumericAgg {
222        &U128_SPLITTER_AGG
223    }
224}
225
226/// Returns a new context appropriate for operating on numeric datums.
227pub fn cx_datum() -> Context<Numeric> {
228    CX_DATUM.clone()
229}
230
231/// Returns a new context appropriate for operating on numeric aggregates.
232pub fn cx_agg() -> Context<NumericAgg> {
233    CX_AGG.clone()
234}
235
/// Interprets up to 16 big-endian bytes as an unsigned `u128`, zero-extending
/// shorter inputs.
///
/// # Panics
///
/// Panics if `input` is longer than 16 bytes.
fn twos_complement_be_to_u128(input: &[u8]) -> u128 {
    assert!(input.len() <= 16);
    input
        .iter()
        .fold(0u128, |acc, &byte| (acc << 8) | u128::from(byte))
}
242
/// Negates, in place, a two's complement integer whose bytes are yielded least
/// significant first.
///
/// Using negative binary numbers can require more digits of precision than
/// [`Numeric`] offers, so we need to have the option to swap bytes' signs at
/// the byte- rather than the library-level.
///
/// Two's complement negation keeps every bit up to and including the lowest
/// set bit and inverts every bit above it. Byte-wise this means:
/// - the zero bytes below the lowest set bit are unchanged,
/// - the byte containing it is arithmetically negated, and
/// - every more significant byte is inverted.
///
/// An all-zero input is left unchanged. Pass `.rev()` of a big-endian buffer
/// to negate a big-endian value.
fn negate_twos_complement_le<'a, I>(b: I)
where
    I: Iterator<Item = &'a mut u8>,
{
    let mut bytes = b;
    // Skip the run of low-order zero bytes; negation leaves them as they are.
    let Some(lowest_nonzero) = bytes.find(|byte| **byte != 0) else {
        return;
    };
    *lowest_nonzero = lowest_nonzero.wrapping_neg();
    for byte in bytes {
        *byte = !*byte;
    }
}
264
265/// Converts an [`Numeric`] into its big endian two's complement representation.
266pub fn numeric_to_twos_complement_be(
267    mut numeric: Numeric,
268) -> [u8; Numeric::TWOS_COMPLEMENT_BYTE_WIDTH] {
269    let mut buf = [0; Numeric::TWOS_COMPLEMENT_BYTE_WIDTH];
270    // Avro doesn't specify how to handle NaN/infinity, so we simply treat them
271    // as zeroes so as to avoid erroring (encoding values is meant to be
272    // infallible) and retain downstream associativity/commutativity.
273    if numeric.is_special() {
274        return buf;
275    }
276
277    let mut cx = Numeric::context();
278
279    // Ensure `numeric` is a canonical coefficient.
280    if numeric.exponent() < 0 {
281        let s = Numeric::from(-numeric.exponent());
282        cx.scaleb(&mut numeric, &s);
283    }
284
285    numeric_to_twos_complement_inner::<Numeric, NUMERIC_DATUM_WIDTH_USIZE>(
286        numeric, &mut cx, &mut buf,
287    );
288    buf
289}
290
291/// Converts an [`Numeric`] into a big endian two's complement representation where
292/// the encoded value has [`NUMERIC_AGG_MAX_PRECISION`] digits and a scale of
293/// [`NUMERIC_DATUM_MAX_PRECISION`].
294///
295/// This representation is appropriate to use in
296/// contexts requiring two's complement representation but `Numeric` values' scale
297/// isn't known, e.g. when working with columns with an explicitly defined
298/// scale.
299pub fn numeric_to_twos_complement_wide(
300    numeric: Numeric,
301) -> [u8; NumericAgg::TWOS_COMPLEMENT_BYTE_WIDTH] {
302    let mut buf = [0; NumericAgg::TWOS_COMPLEMENT_BYTE_WIDTH];
303    // Avro doesn't specify how to handle NaN/infinity, so we simply treat them
304    // as zeroes so as to avoid erroring (encoding values is meant to be
305    // infallible) and retain downstream associativity/commutativity.
306    if numeric.is_special() {
307        return buf;
308    }
309    let mut cx = NumericAgg::context();
310    let mut d = cx.to_width(numeric);
311    let mut scaler = NumericAgg::from(NUMERIC_DATUM_MAX_PRECISION);
312    cx.neg(&mut scaler);
313    // Shape `d` so that its exponent is -NUMERIC_DATUM_MAX_PRECISION
314    cx.rescale(&mut d, &scaler);
315    // Adjust `d` so it is a canonical coefficient, i.e. its exact value can be
316    // recovered by setting its exponent to -39.
317    cx.abs(&mut scaler);
318    cx.scaleb(&mut d, &scaler);
319
320    numeric_to_twos_complement_inner::<NumericAgg, NUMERIC_AGG_WIDTH_USIZE>(d, &mut cx, &mut buf);
321    buf
322}
323
324fn numeric_to_twos_complement_inner<D: Dec<N>, const N: usize>(
325    mut d: Decimal<N>,
326    cx: &mut Context<Decimal<N>>,
327    buf: &mut [u8],
328) {
329    // Adjust negative values to be writable as series of `u128`.
330    let is_neg = if d.is_negative() {
331        cx.neg(&mut d);
332        true
333    } else {
334        false
335    };
336
337    // Values have all been made into canonical coefficients.
338    assert!(d.exponent() >= 0);
339
340    let mut buf_cursor = 0;
341    while !d.is_zero() {
342        let mut w = d.clone();
343        // Take the remainder; this represents one of our "units" to take the coefficient of, i.e. d & u128::MAX
344        cx.rem(&mut w, D::u128_splitter());
345
346        // Take the `u128` version of the coefficient, which will always be what
347        // we want given that we adjusted negative values to have an unsigned
348        // integer representation.
349        let c = w.coefficient::<u128>().unwrap();
350
351        // Determine the width of the coefficient we want to take, i.e. the full
352        // coefficient or a part of it to fill the buffer.
353        let e = std::cmp::min(buf_cursor + 16, D::TWOS_COMPLEMENT_BYTE_WIDTH);
354
355        // We're putting less significant bytes at index 0, which is little endian.
356        buf[buf_cursor..e].copy_from_slice(&c.to_le_bytes()[0..e - buf_cursor]);
357        // Advance cursor; ok that it will go past buffer on final + 1th iteration.
358        buf_cursor += 16;
359
360        // Take the quotient to represent the next unit, i.e. d >> 128
361        cx.div_integer(&mut d, D::u128_splitter());
362    }
363
364    if is_neg {
365        negate_twos_complement_le(buf.iter_mut());
366    }
367
368    // Convert from little endian to big endian.
369    buf.reverse();
370}
371
372pub fn twos_complement_be_to_numeric(
373    input: &mut [u8],
374    scale: u8,
375) -> Result<Numeric, anyhow::Error> {
376    let mut cx = cx_datum();
377    if input.len() <= 17 {
378        if let Ok(mut n) =
379            twos_complement_be_to_numeric_inner::<Numeric, NUMERIC_DATUM_WIDTH_USIZE>(input)
380        {
381            n.set_exponent(-i32::from(scale));
382            return Ok(n);
383        }
384    }
385    // If bytes were invalid for narrower representation, try to use wider
386    // representation in case e.g. simply has more trailing zeroes.
387    let mut n = twos_complement_be_to_numeric_inner::<NumericAgg, NUMERIC_AGG_WIDTH_USIZE>(input)?;
388    // Exponent must be set before converting to `Numeric` width, otherwise values can overflow 39 dop.
389    n.set_exponent(-i32::from(scale));
390    let d = cx.to_width(n);
391    if cx.status().inexact() {
392        bail!("Value exceeds maximum numeric value")
393    }
394    Ok(d)
395}
396
397/// Parses a buffer of two's complement digits in big-endian order and converts
398/// them to [`Decimal<N>`].
399pub fn twos_complement_be_to_numeric_inner<D: Dec<N>, const N: usize>(
400    input: &mut [u8],
401) -> Result<Decimal<N>, anyhow::Error> {
402    let is_neg = if (input[0] & 0x80) != 0 {
403        // byte-level negate all negative values, guaranteeing all bytes are
404        // readable as unsigned.
405        negate_twos_complement_le(input.iter_mut().rev());
406        true
407    } else {
408        false
409    };
410
411    let head = input.len() % 16;
412    let i = twos_complement_be_to_u128(&input[0..head]);
413    let mut cx = D::context();
414    let mut d = cx.from_u128(i);
415
416    for c in input[head..].chunks(16) {
417        assert_eq!(c.len(), 16);
418        // essentially d << 128
419        cx.mul(&mut d, D::u128_splitter());
420        let i = twos_complement_be_to_u128(c);
421        let i = cx.from_u128(i);
422        cx.add(&mut d, &i);
423    }
424
425    if cx.status().inexact() {
426        bail!("Value exceeds maximum numeric value")
427    } else if cx.status().any() {
428        bail!("unexpected status {:?}", cx.status());
429    }
430    if is_neg {
431        cx.neg(&mut d);
432    }
433    Ok(d)
434}
435
436#[mz_ore::test]
437#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
438fn test_twos_complement_roundtrip() {
439    fn inner(s: &str) {
440        let mut cx = cx_datum();
441        let d = cx.parse(s).unwrap();
442        let scale = std::cmp::min(d.exponent(), 0).abs();
443        let mut b = numeric_to_twos_complement_be(d.clone());
444        let x = twos_complement_be_to_numeric(&mut b, u8::try_from(scale).unwrap()).unwrap();
445        assert_eq!(d, x);
446    }
447    inner("0");
448    inner("0.000000000000000000000000000000000012345");
449    inner("0.123456789012345678901234567890123456789");
450    inner("1.00000000000000000000000000000000000000");
451    inner("1");
452    inner("2");
453    inner("170141183460469231731687303715884105727");
454    inner("170141183460469231731687303715884105728");
455    inner("12345678901234567890.1234567890123456789");
456    inner("999999999999999999999999999999999999999");
457    inner("7e35");
458    inner("7e-35");
459    inner("-0.000000000000000000000000000000000012345");
460    inner("-0.12345678901234567890123456789012345678");
461    inner("-1.00000000000000000000000000000000000000");
462    inner("-1");
463    inner("-2");
464    inner("-170141183460469231731687303715884105727");
465    inner("-170141183460469231731687303715884105728");
466    inner("-12345678901234567890.1234567890123456789");
467    inner("-999999999999999999999999999999999999999");
468    inner("-7.2e35");
469    inner("-7.2e-35");
470}
471
472#[mz_ore::test]
473#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
474fn test_twos_comp_numeric_primitive() {
475    fn inner_inner<P>(i: P, i_be_bytes: &mut [u8])
476    where
477        P: Into<Numeric> + TryFrom<Numeric> + Eq + PartialEq + std::fmt::Debug + Copy,
478    {
479        let mut e = [0; Numeric::TWOS_COMPLEMENT_BYTE_WIDTH];
480        e[Numeric::TWOS_COMPLEMENT_BYTE_WIDTH - i_be_bytes.len()..].copy_from_slice(i_be_bytes);
481        let mut w = [0; NumericAgg::TWOS_COMPLEMENT_BYTE_WIDTH];
482        w[NumericAgg::TWOS_COMPLEMENT_BYTE_WIDTH - i_be_bytes.len()..].copy_from_slice(i_be_bytes);
483
484        let d: Numeric = i.into();
485
486        // Extend negative sign into most-significant bits
487        if d.is_negative() {
488            for i in e[..Numeric::TWOS_COMPLEMENT_BYTE_WIDTH - i_be_bytes.len()].iter_mut() {
489                *i = 0xFF;
490            }
491            for i in w[..NumericAgg::TWOS_COMPLEMENT_BYTE_WIDTH - i_be_bytes.len()].iter_mut() {
492                *i = 0xFF;
493            }
494        }
495
496        // Ensure decimal value's two's complement representation matches an
497        // extended version of `to_be_bytes`.
498        let d_be_bytes = numeric_to_twos_complement_be(d);
499        assert_eq!(
500            e, d_be_bytes,
501            "expected repr of {:?}, got {:?}",
502            e, d_be_bytes
503        );
504
505        // Ensure extended version of `to_be_bytes` generates same `i128`.
506        let e_numeric = twos_complement_be_to_numeric(&mut e, 0).unwrap();
507        let e_p: P = e_numeric
508            .try_into()
509            .unwrap_or_else(|_e| panic!("try_into failed"));
510        assert_eq!(i, e_p, "expected val of {:?}, got {:?}", i, e_p);
511
512        // Wide representation produces same result.
513        let w_numeric = twos_complement_be_to_numeric(&mut w, 0).unwrap();
514        let w_p: P = w_numeric
515            .try_into()
516            .unwrap_or_else(|_e| panic!("try_into failed"));
517        assert_eq!(i, w_p, "expected val of {:?}, got {:?}", i, e_p);
518
519        // Bytes do not need to be in `Numeric`-specific format
520        let p_numeric = twos_complement_be_to_numeric(i_be_bytes, 0).unwrap();
521        let p_p: P = p_numeric
522            .try_into()
523            .unwrap_or_else(|_e| panic!("try_into failed"));
524        assert_eq!(i, p_p, "expected val of {:?}, got {:?}", i, p_p);
525    }
526
527    fn inner_i32(i: i32) {
528        inner_inner(i, &mut i.to_be_bytes());
529    }
530
531    fn inner_i64(i: i64) {
532        inner_inner(i, &mut i.to_be_bytes());
533    }
534
535    // We need a wrapper around i128 to implement the same traits as the other
536    // primitive types. This is less code than a second implementation of the
537    // same test that takes unwrapped i128s.
538    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
539    struct FromableI128 {
540        i: i128,
541    }
542    impl From<i128> for FromableI128 {
543        fn from(i: i128) -> FromableI128 {
544            FromableI128 { i }
545        }
546    }
547    impl From<FromableI128> for Numeric {
548        fn from(n: FromableI128) -> Numeric {
549            Numeric::try_from(n.i).unwrap()
550        }
551    }
552    impl TryFrom<Numeric> for FromableI128 {
553        type Error = ();
554        fn try_from(n: Numeric) -> Result<FromableI128, Self::Error> {
555            match i128::try_from(n) {
556                Ok(i) => Ok(FromableI128 { i }),
557                Err(_) => Err(()),
558            }
559        }
560    }
561
562    fn inner_i128(i: i128) {
563        inner_inner(FromableI128::from(i), &mut i.to_be_bytes());
564    }
565
566    inner_i32(0);
567    inner_i32(1);
568    inner_i32(2);
569    inner_i32(-1);
570    inner_i32(-2);
571    inner_i32(i32::MAX);
572    inner_i32(i32::MIN);
573    inner_i32(i32::MAX / 7 + 7);
574    inner_i32(i32::MIN / 7 + 7);
575    inner_i64(0);
576    inner_i64(1);
577    inner_i64(2);
578    inner_i64(-1);
579    inner_i64(-2);
580    inner_i64(i64::MAX);
581    inner_i64(i64::MIN);
582    inner_i64(i64::MAX / 7 + 7);
583    inner_i64(i64::MIN / 7 + 7);
584    inner_i128(0);
585    inner_i128(1);
586    inner_i128(2);
587    inner_i128(-1);
588    inner_i128(-2);
589    inner_i128(i128::from(i64::MAX));
590    inner_i128(i128::from(i64::MIN));
591    inner_i128(i128::MAX);
592    inner_i128(i128::MIN);
593    inner_i128(i128::MAX / 7 + 7);
594    inner_i128(i128::MIN / 7 + 7);
595}
596
597#[mz_ore::test]
598#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
599fn test_twos_complement_to_numeric_fail() {
600    fn inner(b: &mut [u8]) {
601        let r = twos_complement_be_to_numeric(b, 0);
602        mz_ore::assert_err!(r);
603    }
604    // 17-byte signed digit's max value exceeds 39 digits of precision
605    let mut e = [0xFF; Numeric::TWOS_COMPLEMENT_BYTE_WIDTH];
606    e[0] -= 0x80;
607    inner(&mut e);
608
609    // 1 << 17 * 8 exceeds exceeds 39 digits of precision
610    let mut e = [0; Numeric::TWOS_COMPLEMENT_BYTE_WIDTH + 1];
611    e[0] = 1;
612    inner(&mut e);
613}
614
615#[mz_ore::test]
616#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
617fn test_wide_twos_complement_roundtrip() {
618    fn inner(s: &str) {
619        let mut cx = cx_datum();
620        let d = cx.parse(s).unwrap();
621        let mut b = numeric_to_twos_complement_wide(d.clone());
622        let x = twos_complement_be_to_numeric(&mut b, NUMERIC_DATUM_MAX_PRECISION).unwrap();
623        assert_eq!(d, x);
624    }
625    inner("0");
626    inner("0.000000000000000000000000000000000012345");
627    inner("0.123456789012345678901234567890123456789");
628    inner("1.00000000000000000000000000000000000000");
629    inner("1");
630    inner("2");
631    inner("170141183460469231731687303715884105727");
632    inner("170141183460469231731687303715884105728");
633    inner("12345678901234567890.1234567890123456789");
634    inner("999999999999999999999999999999999999999");
635    inner("-0.000000000000000000000000000000000012345");
636    inner("-0.123456789012345678901234567890123456789");
637    inner("-1.00000000000000000000000000000000000000");
638    inner("-1");
639    inner("-2");
640    inner("-170141183460469231731687303715884105727");
641    inner("-170141183460469231731687303715884105728");
642    inner("-12345678901234567890.1234567890123456789");
643    inner("-999999999999999999999999999999999999999");
644}
645
646/// Returns `n`'s precision, i.e. the total number of digits represented by `n`
647/// in standard notation not including a zero in the "one's place" in (-1,1).
648pub fn get_precision<const N: usize>(n: &Decimal<N>) -> u32 {
649    let e = n.exponent();
650    if e >= 0 {
651        // Positive exponent
652        n.digits() + u32::try_from(e).unwrap()
653    } else {
654        // Negative exponent
655        let d = n.digits();
656        let e = u32::try_from(e.abs()).unwrap();
657        // Precision is...
658        // - d if decimal point splits numbers
659        // - e if e dominates number of digits
660        std::cmp::max(d, e)
661    }
662}
663
664/// Returns `n`'s scale, i.e. the number of digits used after the decimal point.
665pub fn get_scale(n: &Numeric) -> u32 {
666    let exp = n.exponent();
667    if exp >= 0 { 0 } else { exp.unsigned_abs() }
668}
669
670/// Ensures [`Numeric`] values are:
671/// - Within `Numeric`'s max precision ([`NUMERIC_DATUM_MAX_PRECISION`]), or errors if not.
672/// - Never possible but invalid representations (i.e. never -Nan or -0).
673///
674/// Should be called after any operation that can change an [`Numeric`]'s scale or
675/// generate negative values (except addition and subtraction).
676pub fn munge_numeric(n: &mut Numeric) -> Result<(), anyhow::Error> {
677    rescale_within_max_precision(n)?;
678    if (n.is_zero() || n.is_nan()) && n.is_negative() {
679        cx_datum().neg(n);
680    }
681    Ok(())
682}
683
684/// Rescale's `n` to fit within [`Numeric`]'s max precision or error if not
685/// possible.
686fn rescale_within_max_precision(n: &mut Numeric) -> Result<(), anyhow::Error> {
687    let current_precision = get_precision(n);
688    if current_precision > u32::from(NUMERIC_DATUM_MAX_PRECISION) {
689        if n.exponent() < 0 {
690            let precision_diff = current_precision - u32::from(NUMERIC_DATUM_MAX_PRECISION);
691            let current_scale = u8::try_from(get_scale(n))?;
692            let scale_diff = current_scale - u8::try_from(precision_diff).unwrap();
693            rescale(n, scale_diff)?;
694        } else {
695            bail!(
696                "numeric value {} exceed maximum precision {}",
697                n,
698                NUMERIC_DATUM_MAX_PRECISION
699            )
700        }
701    }
702    Ok(())
703}
704
705/// Rescale `n` as an `OrderedDecimal` with the described scale, or error if:
706/// - Rescaling exceeds max precision
707/// - `n` requires > [`NUMERIC_DATUM_MAX_PRECISION`] - `scale` digits of precision
708///   left of the decimal point
709pub fn rescale(n: &mut Numeric, scale: u8) -> Result<(), anyhow::Error> {
710    let mut cx = cx_datum();
711    cx.rescale(n, &Numeric::from(-i32::from(scale)));
712    if cx.status().invalid_operation() || get_precision(n) > u32::from(NUMERIC_DATUM_MAX_PRECISION)
713    {
714        bail!(
715            "numeric value {} exceed maximum precision {}",
716            n,
717            NUMERIC_DATUM_MAX_PRECISION
718        )
719    }
720    munge_numeric(n)?;
721
722    Ok(())
723}
724
/// A type that can represent real numbers, supporting basic arithmetic and
/// conversion from small primitive numeric types.
///
/// Useful for interoperability between [`Numeric`] and floating point: this
/// module implements it for both [`f64`] and [`Numeric`].
pub trait DecimalLike:
    std::ops::Add<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Div<Output = Self>
    + From<i8>
    + From<i16>
    + From<i32>
    + From<u8>
    + From<u16>
    + From<u32>
    + From<f32>
    + From<f64>
{
    /// Converts an `i64` into `Self`. Depending on the implementation this may
    /// lose precision (e.g. not every `i64` is exactly representable as an
    /// `f64`).
    fn lossy_from(i: i64) -> Self;
}
745
746impl DecimalLike for f64 {
747    // No other known way to convert `i64` to `f64`.
748    #[allow(clippy::as_conversions)]
749    fn lossy_from(i: i64) -> Self {
750        i as f64
751    }
752}
753
754impl DecimalLike for Numeric {
755    fn lossy_from(i: i64) -> Self {
756        Numeric::from(i)
757    }
758}
759
760/// An encoded packed variant of [`Numeric`].
761///
762/// Unlike other "Packed" types we _DO NOT_ uphold the invariant that
763/// [`PackedNumeric`] sorts the same as [`Numeric`].
764#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
765pub struct PackedNumeric(pub [u8; 40]);
766
767impl FixedSizeCodec<Numeric> for PackedNumeric {
768    const SIZE: usize = 40;
769
770    fn as_bytes(&self) -> &[u8] {
771        &self.0
772    }
773
774    fn from_bytes(slice: &[u8]) -> Result<Self, String> {
775        let buf: [u8; Self::SIZE] = slice.try_into().map_err(|_| {
776            format!(
777                "size for PackedNumeric is {} bytes, got {}",
778                Self::SIZE,
779                slice.len()
780            )
781        })?;
782        Ok(PackedNumeric(buf))
783    }
784
785    fn from_value(val: Numeric) -> PackedNumeric {
786        let (digits, exponent, bits, lsu) = val.to_raw_parts();
787
788        let mut buf = [0u8; 40];
789
790        buf[0..4].copy_from_slice(&digits.to_le_bytes());
791        buf[4..8].copy_from_slice(&exponent.to_le_bytes());
792
793        for i in 0..13 {
794            buf[(i * 2) + 8..(i * 2) + 10].copy_from_slice(&lsu[i].to_le_bytes());
795        }
796
797        buf[34..35].copy_from_slice(&bits.to_le_bytes());
798
799        PackedNumeric(buf)
800    }
801
802    fn into_value(self) -> Numeric {
803        let digits: [u8; 4] = self.0[0..4].try_into().unwrap();
804        let digits = u32::from_le_bytes(digits);
805
806        let exponent: [u8; 4] = self.0[4..8].try_into().unwrap();
807        let exponent = i32::from_le_bytes(exponent);
808
809        let mut lsu = [0u16; 13];
810        for i in 0..13 {
811            let x: [u8; 2] = self.0[(i * 2) + 8..(i * 2) + 10].try_into().unwrap();
812            let x = u16::from_le_bytes(x);
813            lsu[i] = x;
814        }
815
816        let bits: [u8; 1] = self.0[34..35].try_into().unwrap();
817        let bits = u8::from_le_bytes(bits);
818
819        Numeric::from_raw_parts(digits, exponent, bits, lsu)
820    }
821}
822
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use proptest::prelude::*;

    use crate::scalar::arb_numeric;

    use super::*;

    proptest! {
        #[mz_ore::test]
        fn numeric_max_scale_protobuf_roundtrip(expect in any::<NumericMaxScale>()) {
            let actual = protobuf_roundtrip::<_, ProtoNumericMaxScale>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }

        #[mz_ore::test]
        fn optional_numeric_max_scale_protobuf_roundtrip(
            expect in any::<Option<NumericMaxScale>>(),
        ) {
            let actual = protobuf_roundtrip::<_, ProtoOptionalNumericMaxScale>(&expect);
            assert_ok!(actual);
            assert_eq!(actual.unwrap(), expect);
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
    fn smoketest_packed_numeric_roundtrips() {
        let og = PackedNumeric::from_value(Numeric::from(-42));
        let bytes = og.as_bytes();
        let rnd = PackedNumeric::from_bytes(bytes).expect("valid");
        assert_eq!(og, rnd);

        // Returns an error if the size of the slice is invalid.
        mz_ore::assert_err!(PackedNumeric::from_bytes(&[0, 0, 0, 0]));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decNumberCopyNegate` on OS `linux`
    fn proptest_packed_numeric_roundtrip() {
        fn test(og: Numeric) {
            let packed = PackedNumeric::from_value(og);
            let rnd = packed.into_value();

            // NaN != NaN, so compare NaNs by kind only.
            if og.is_nan() && rnd.is_nan() {
                return;
            }
            assert_eq!(og, rnd);
        }

        proptest!(|(num in arb_numeric())| {
            test(num);
        });
    }

    // Note: It's expected that this test will fail if you update the strategy
    // for generating an arbitrary Numeric. In that case feel free to
    // regenerate the snapshot.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decNumberCopyNegate` on OS `linux`
    fn packed_numeric_stability() {
        /// This is the seed [`proptest`] uses for their deterministic RNG. We
        /// copy it here to prevent breaking this test if [`proptest`] changes.
        const RNG_SEED: [u8; 32] = [
            0xf4, 0x16, 0x16, 0x48, 0xc3, 0xac, 0x77, 0xac, 0x72, 0x20, 0x0b, 0xea, 0x99, 0x67,
            0x2d, 0x6d, 0xca, 0x9f, 0x76, 0xaf, 0x1b, 0x09, 0x73, 0xa0, 0x59, 0x22, 0x6d, 0xc5,
            0x46, 0x39, 0x1c, 0x4a,
        ];

        let rng = proptest::test_runner::TestRng::from_seed(
            proptest::test_runner::RngAlgorithm::ChaCha,
            &RNG_SEED,
        );
        // Configure a deterministic runner from which to draw Numerics.
        let config = proptest::test_runner::Config {
            // We let the loop below drive how much data we generate.
            cases: u32::MAX,
            rng_algorithm: proptest::test_runner::RngAlgorithm::ChaCha,
            ..Default::default()
        };
        let mut runner = proptest::test_runner::TestRunner::new_with_rng(config, rng);

        let test_cases = 2_000;
        let strat = arb_numeric();

        let mut all_numerics = Vec::new();
        for _ in 0..test_cases {
            let value_tree = strat.new_tree(&mut runner).unwrap();
            let numeric = value_tree.current();
            let packed = PackedNumeric::from_value(numeric);
            let hex_bytes = format!("{:x?}", packed.as_bytes());

            all_numerics.push((numeric, hex_bytes));
        }

        insta::assert_debug_snapshot!(all_numerics);
    }
}