use std::error::Error;
use std::fmt;
use std::sync::LazyLock;
use byteorder::{NetworkEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use dec::OrderedDecimal;
use mz_ore::cast::CastFrom;
use mz_repr::adt::numeric::{self, cx_datum, Numeric as AdtNumeric, NumericAgg};
use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
/// A `numeric` value wrapped for exchange with PostgreSQL in the binary wire
/// format (see the `ToSql` and `FromSql` impls for `Numeric`).
#[derive(Debug)]
pub struct Numeric(pub OrderedDecimal<AdtNumeric>);
impl fmt::Display for Numeric {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl From<AdtNumeric> for Numeric {
fn from(n: AdtNumeric) -> Numeric {
Numeric(OrderedDecimal(n))
}
}
/// Number of decimal digits packed into each 16-bit unit of the binary
/// `numeric` wire format (units are base 10^4).
const TO_FROM_SQL_BASE_POW: u32 = 4;
/// `10^TO_FROM_SQL_BASE_POW` (10,000): the divisor used to split a value into
/// wire-format units when encoding.
static TO_SQL_BASER: LazyLock<AdtNumeric> =
    LazyLock::new(|| AdtNumeric::from(10u32.pow(TO_FROM_SQL_BASE_POW)));
/// `TO_FROM_SQL_BASE_POW` as a numeric, passed to `scaleb` when decoding to
/// shift the accumulator left by one unit before adding the next unit.
static FROM_SQL_SCALER: LazyLock<AdtNumeric> =
    LazyLock::new(|| AdtNumeric::from(TO_FROM_SQL_BASE_POW));
/// Length of the fixed unit buffer used when encoding.
/// NOTE(review): presumably sized to hold the maximum numeric precision
/// expressed in base-10,000 units plus alignment slack — confirm.
const UNITS_LEN: usize = 11;
impl ToSql for Numeric {
    /// Encodes the value in PostgreSQL's binary `numeric` wire format.
    ///
    /// Writes, in big-endian order: the number of base-10,000 digit units
    /// (`u16`), the weight of the first unit (`i16`), a sign/special-value flag
    /// (`u16`), the display scale (`u16`), and then the digit units themselves
    /// (`u16` each; omitted for NaN).
    ///
    /// # Errors
    ///
    /// Returns an error if the value's scale does not fit in a `u16`.
    ///
    /// # Panics
    ///
    /// The `expect`/`unwrap` calls below assert internal bounds (exponent and
    /// unit-count ranges) that are expected to hold for any valid datum.
    fn to_sql(
        &self,
        _: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn Error + 'static + Send + Sync>> {
        // Work on a local copy; `d` is mutated (abs, rescaled, consumed) below.
        let mut d = self.0 .0.clone();
        let scale = u16::try_from(numeric::get_scale(&d))?;
        // Classify the value before `abs`/`scaleb` mutate `d`. Zero is never
        // reported as negative, so `-0` is encoded with a positive sign.
        let is_zero = d.is_zero();
        let is_nan = d.is_nan();
        let is_neg = d.is_negative() && !is_zero;
        let is_infinite = d.is_infinite();
        // Widen the context's exponent range by one unit's worth of digits in
        // each direction. NOTE(review): presumably so the unit-aligning
        // `scaleb` below cannot overflow/underflow the datum context — confirm.
        let mut cx = numeric::cx_datum();
        cx.set_max_exponent(cx.max_exponent() + isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
            .unwrap();
        cx.set_min_exponent(cx.min_exponent() - isize::cast_from(i64::from(TO_FROM_SQL_BASE_POW)))
            .unwrap();
        // The sign travels in the flag field, so only the magnitude is encoded.
        cx.abs(&mut d);
        // Units are written into this buffer from the end (least significant
        // first); `d_i` is the index of the most significant unit written.
        let mut digits = [0u16; UNITS_LEN];
        let mut d_i = UNITS_LEN;
        // For values with a fractional part, shift `d` to an integer whose
        // fractional digits fill whole units, and record how many units are
        // fractional and how many all-zero units sit between the decimal point
        // and the first significant digit (e.g. the `0000` in `0.00001234`).
        let (fract_units, leading_zero_units) = if d.exponent() < 0 {
            // Number of digits to the right of the decimal point.
            let pos_exp = usize::try_from(-d.exponent()).expect("positive value < 40");
            // Leading zero units only exist when every digit is fractional;
            // a zero value contributes no significant digits.
            let leading_zero_units = if pos_exp >= usize::cast_from(d.digits()) {
                let digits = if d.is_zero() {
                    0
                } else {
                    usize::cast_from(d.digits())
                };
                // ceil((pos_exp - digits) / TO_FROM_SQL_BASE_POW)
                (pos_exp - digits + usize::cast_from(TO_FROM_SQL_BASE_POW) - 1)
                    / usize::cast_from(TO_FROM_SQL_BASE_POW)
            } else {
                0
            };
            // Round the shift up to a multiple of TO_FROM_SQL_BASE_POW so unit
            // boundaries line up with the decimal point.
            let s = pos_exp % usize::cast_from(TO_FROM_SQL_BASE_POW);
            let unit_shift_exp = if s != 0 {
                pos_exp + usize::cast_from(TO_FROM_SQL_BASE_POW) - s
            } else {
                pos_exp
            };
            // Make `d` an integer: d *= 10^unit_shift_exp.
            cx.scaleb(&mut d, &AdtNumeric::from(unit_shift_exp));
            (
                u16::try_from(unit_shift_exp / usize::cast_from(TO_FROM_SQL_BASE_POW))
                    .expect("value < 40"),
                leading_zero_units,
            )
        } else {
            (0, 0)
        };
        // Peel off base-10,000 units from the least significant end: `d`
        // becomes the remainder (the unit) and `w` the quotient carried into the
        // next iteration. Zero and special values (NaN/infinity) emit no units.
        let mut w = d.clone();
        while !d.is_zero() && !d.is_special() {
            d_i -= 1;
            cx.rem(&mut d, &TO_SQL_BASER);
            digits[d_i] =
                u16::try_from(u32::try_from(d).expect("value < 10,000")).expect("value < 10,000");
            cx.div_integer(&mut w, &TO_SQL_BASER);
            d = w;
        }
        // Include the leading zero units; those buffer slots are still 0 from
        // initialization.
        d_i -= leading_zero_units;
        let units = u16::try_from(UNITS_LEN - d_i).unwrap();
        // Weight of the most significant unit: the number of integer units
        // minus one. Zero is sent with weight 0.
        let weight = if is_zero {
            0
        } else {
            i16::try_from(units - fract_units).unwrap() - 1
        };
        out.put_u16(units);
        out.put_i16(weight);
        // Sign/special flag: 0xD000 = +Infinity, 0xF000 = -Infinity,
        // 0xC000 = NaN, 0x4000 = negative, 0 = positive (the same values the
        // `FromSql` impl decodes).
        out.put_u16(if is_infinite {
            if is_neg {
                0xF000
            } else {
                0xD000
            }
        } else if is_nan {
            0xC000
        } else if is_neg {
            0x4000
        } else {
            0
        });
        out.put_u16(scale);
        // NaN carries no digit units.
        if !is_nan {
            for digit in digits[d_i..].iter() {
                out.put_u16(*digit);
            }
        }
        Ok(IsNull::No)
    }
    /// Only the `numeric` type is supported.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }
    to_sql_checked!();
}
impl<'a> FromSql<'a> for Numeric {
    /// Decodes a PostgreSQL binary-format `numeric` value.
    ///
    /// The input starts with four big-endian 16-bit header fields: the number
    /// of digit units, the weight of the first unit, the sign/special-value
    /// flag, and the display scale. These are followed by `units` big-endian
    /// `u16` digit units, each worth `10^TO_FROM_SQL_BASE_POW` times the next.
    ///
    /// # Errors
    ///
    /// Returns an error if `raw` is truncated, if the unit count is negative,
    /// if the sign field is not a recognized value, or if the decoded value
    /// cannot be represented as an [`AdtNumeric`].
    fn from_sql(_: &Type, mut raw: &'a [u8]) -> Result<Numeric, Box<dyn Error + Sync + Send>> {
        let units = raw.read_i16::<NetworkEndian>()?;
        let weight = raw.read_i16::<NetworkEndian>()?;
        let sign = raw.read_u16::<NetworkEndian>()?;
        let in_scale = raw.read_i16::<NetworkEndian>()?;
        // Validate the unit count up front. `format!` is required here: a plain
        // string literal would emit `{units}` verbatim instead of the value.
        let units_usize = usize::try_from(units)
            .map_err(|_| format!("units must not be negative: {units}"))?;
        let mut digits = Vec::with_capacity(units_usize);
        for _ in 0..units_usize {
            digits.push(raw.read_u16::<NetworkEndian>()?);
        }
        // Accumulate the units as one integer (most significant first) in the
        // wider aggregation context.
        let mut cx = numeric::cx_agg();
        let mut d = NumericAgg::zero();
        for digit in &digits {
            cx.scaleb(&mut d, &FROM_SQL_SCALER);
            let n = AdtNumeric::from(u32::from(*digit));
            cx.add(&mut d, &n);
        }
        match sign {
            0 => (),
            0xD000 => return Ok(Numeric::from(AdtNumeric::infinity())),
            0xF000 => {
                let mut cx = numeric::cx_datum();
                let mut d = AdtNumeric::infinity();
                cx.neg(&mut d);
                return Ok(Numeric::from(d));
            }
            0x4000 => cx.neg(&mut d),
            0xC000 => return Ok(Numeric::from(AdtNumeric::nan())),
            _ => return Err("bad sign in numeric".into()),
        }
        // Number of fractional decimal digits implied by the header. Computed in
        // i32 because `(units - weight - 1) * 4` can overflow i16 for crafted
        // header values (panicking in debug builds, wrapping in release).
        let mut scale = (i32::from(units) - i32::from(weight) - 1) * 4;
        let in_scale = i32::from(in_scale);
        if scale < 0 {
            // Weight extends past the last unit: append trailing zeros.
            cx.scaleb(&mut d, &AdtNumeric::from(-scale));
            scale = 0;
        } else if scale > in_scale {
            // Drop fractional digits beyond the declared display scale.
            cx.scaleb(&mut d, &AdtNumeric::from(-(scale - in_scale)));
            scale = in_scale;
        }
        cx.scaleb(&mut d, &AdtNumeric::from(-scale));
        cx.reduce(&mut d);
        let mut cx = cx_datum();
        let d_datum = cx.to_width(d);
        // Genuine NaN/infinity were returned via the sign flag above, so a NaN
        // or infinite accumulator here means the header drove the arithmetic
        // out of range (e.g. an invalid `scaleb` operand yields NaN).
        if d.is_nan() || d.is_infinite() || cx.status().any() {
            return Err(format!("Unable to take bytes to numeric value; rendered {}", d).into());
        }
        Ok(Numeric::from(d_datum))
    }
    /// Only the `numeric` type is supported.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }
}
#[mz_ore::test]
#[cfg_attr(miri, ignore)] fn test_to_from_sql_roundtrip() {
fn inner(s: &str) {
let mut cx = numeric::cx_datum();
let d = cx.parse(s).unwrap();
let r = Numeric(OrderedDecimal(d));
let mut out = BytesMut::new();
let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
assert_eq!(r.0, d_from_sql.0);
}
inner("0");
inner("-0");
inner("0.1");
inner("0.0");
inner("0.00");
inner("0.000");
inner("0.0000");
inner("0.00000");
inner("123456789.012346789");
inner("000000000000000000000000000000000000001");
inner("000000000000000000000000000000000000000");
inner("999999999999999999999999999999999999999");
inner("123456789012345678901234567890123456789");
inner("-123456789012345678901234567890123456789");
inner(".123456789012345678901234567890123456789");
inner(".000000000000000000000000000000000000001");
inner(".000000000000000000000000000000000000000");
inner(".999999999999999999999999999999999999999");
inner("-0.123456789012345678901234567890123456789");
inner("1e25");
inner("-1e25");
inner("9.876e-25");
inner("-9.876e-25");
inner("98760000");
inner(".00009876");
inner("-.00009876");
inner("NaN");
let mut cx = numeric::cx_datum();
let v = [
cx.parse("-999999999999999999999999999999999999999")
.unwrap(),
cx.parse("-999999999999999999999999999999999999999")
.unwrap(),
];
let s = cx.sum(v.iter());
assert!(s.is_infinite());
let r = Numeric::from(s);
let mut out = BytesMut::new();
let _ = r.to_sql(&Type::NUMERIC, &mut out).unwrap();
let d_from_sql = Numeric::from_sql(&Type::NUMERIC, &out).unwrap();
assert_eq!(r.0, d_from_sql.0);
}