tiberius/tds/
numeric.rs

1//! Representations of numeric types.
2
3use super::codec::Encode;
4use crate::{sql_read_bytes::SqlReadBytes, Error};
5#[cfg(feature = "bigdecimal")]
6#[cfg_attr(feature = "docs", doc(cfg(feature = "bigdecimal")))]
7pub use bigdecimal::{num_bigint::BigInt, BigDecimal};
8use byteorder::{ByteOrder, LittleEndian};
9use bytes::{BufMut, BytesMut};
10#[cfg(feature = "rust_decimal")]
11#[cfg_attr(feature = "docs", doc(cfg(feature = "rust_decimal")))]
12pub use rust_decimal::Decimal;
13use std::cmp::{Ordering, PartialEq};
14use std::fmt::{self, Debug, Display, Formatter};
15
/// A SQL `DECIMAL` / `NUMERIC` value, kept as an unscaled `i128` integer
/// together with a scale (the number of digits after the decimal point).
/// The maximum supported precision is 38 decimal digits.
///
/// For heavier numeric work, the recommended approach is to enable the
/// `rust_decimal` feature and use its `Decimal` type instead.
#[derive(Clone, Copy)]
pub struct Numeric {
    /// The unscaled integer value, e.g. `57705` for `577.05` at scale 2.
    value: i128,
    /// Number of digits after the decimal point.
    scale: u8,
}
26
27impl Numeric {
28    /// Creates a new Numeric value.
29    ///
30    /// # Panic
31    /// It will panic if the scale exceed 37.
32    pub fn new_with_scale(value: i128, scale: u8) -> Self {
33        // scale cannot exceed 37 since a
34        // max precision of 38 is possible here.
35        assert!(scale < 38);
36
37        Numeric { value, scale }
38    }
39
40    /// Extract the decimal part.
41    pub fn dec_part(self) -> i128 {
42        let scale = self.pow_scale();
43        self.value - (self.value / scale) * scale
44    }
45
46    /// Extract the integer part.
47    pub fn int_part(self) -> i128 {
48        self.value / self.pow_scale()
49    }
50
51    #[inline]
52    fn pow_scale(self) -> i128 {
53        10i128.pow(self.scale as u32)
54    }
55
56    /// The scale (where is the decimal point) of the value.
57    #[inline]
58    pub fn scale(self) -> u8 {
59        self.scale
60    }
61
62    /// The internal integer value
63    #[inline]
64    pub fn value(self) -> i128 {
65        self.value
66    }
67
68    /// The precision of the `Number` as a number of digits.
69    pub fn precision(self) -> u8 {
70        let mut result = 0;
71        let mut n = self.int_part();
72
73        while n != 0 {
74            n /= 10;
75            result += 1;
76        }
77
78        if result == 0 {
79            1 + self.scale()
80        } else {
81            result + self.scale()
82        }
83    }
84
85    pub(crate) fn len(self) -> u8 {
86        match self.precision() {
87            1..=9 => 5,
88            10..=19 => 9,
89            20..=28 => 13,
90            _ => 17,
91        }
92    }
93
94    pub(crate) async fn decode<R>(src: &mut R, scale: u8) -> crate::Result<Option<Self>>
95    where
96        R: SqlReadBytes + Unpin,
97    {
98        fn decode_d128(buf: &[u8]) -> u128 {
99            let low_part = LittleEndian::read_u64(&buf[0..]) as u128;
100
101            if !buf[8..].iter().any(|x| *x != 0) {
102                return low_part;
103            }
104
105            let high_part = match buf.len() {
106                12 => LittleEndian::read_u32(&buf[8..]) as u128,
107                16 => LittleEndian::read_u64(&buf[8..]) as u128,
108                _ => unreachable!(),
109            };
110
111            // swap high&low for big endian
112            #[cfg(target_endian = "big")]
113            let (low_part, high_part) = (high_part, low_part);
114
115            let high_part = high_part * (u64::max_value() as u128 + 1);
116            low_part + high_part
117        }
118
119        let len = src.read_u8().await?;
120
121        if len == 0 {
122            Ok(None)
123        } else {
124            let sign = match src.read_u8().await? {
125                0 => -1i128,
126                1 => 1i128,
127                _ => return Err(Error::Protocol("decimal: invalid sign".into())),
128            };
129
130            let value = match len {
131                5 => src.read_u32_le().await? as i128 * sign,
132                9 => src.read_u64_le().await? as i128 * sign,
133                13 => {
134                    let mut bytes = [0u8; 12]; //u96
135                    for item in &mut bytes {
136                        *item = src.read_u8().await?;
137                    }
138                    decode_d128(&bytes) as i128 * sign
139                }
140                17 => {
141                    let mut bytes = [0u8; 16];
142                    for item in &mut bytes {
143                        *item = src.read_u8().await?;
144                    }
145                    decode_d128(&bytes) as i128 * sign
146                }
147                x => {
148                    return Err(Error::Protocol(
149                        format!("decimal/numeric: invalid length of {} received", x).into(),
150                    ))
151                }
152            };
153
154            Ok(Some(Numeric::new_with_scale(value, scale)))
155        }
156    }
157}
158
159impl Encode<BytesMut> for Numeric {
160    fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
161        dst.put_u8(self.len());
162
163        if self.value < 0 {
164            dst.put_u8(0);
165        } else {
166            dst.put_u8(1);
167        }
168
169        let value = self.value().abs();
170
171        match self.len() {
172            5 => dst.put_u32_le(value as u32),
173            9 => dst.put_u64_le(value as u64),
174            13 => {
175                dst.put_u64_le(value as u64);
176                dst.put_u32_le((value >> 64) as u32)
177            }
178            _ => dst.put_u128_le(value as u128),
179        }
180
181        Ok(())
182    }
183}
184
185impl Debug for Numeric {
186    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
187        write!(
188            f,
189            "{}.{:0pad$}",
190            self.int_part(),
191            self.dec_part(),
192            pad = self.scale as usize
193        )
194    }
195}
196
197impl Display for Numeric {
198    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
199        write!(f, "{:?}", self)
200    }
201}
202
// Equality is defined by the `PartialEq` impl, which compares values after
// bringing both operands to a common scale. For example, 1005.01 (scale 2)
// equals 1005.010 (scale 3).
impl Eq for Numeric {}
204
205impl From<Numeric> for f64 {
206    fn from(n: Numeric) -> f64 {
207        n.dec_part() as f64 / n.pow_scale() as f64 + n.int_part() as f64
208    }
209}
210
211impl From<Numeric> for i128 {
212    fn from(n: Numeric) -> i128 {
213        n.int_part()
214    }
215}
216
217impl From<Numeric> for u128 {
218    fn from(n: Numeric) -> u128 {
219        n.int_part() as u128
220    }
221}
222
223impl PartialEq for Numeric {
224    fn eq(&self, other: &Self) -> bool {
225        match self.scale.cmp(&other.scale) {
226            Ordering::Greater => {
227                10i128.pow((self.scale - other.scale) as u32) * other.value == self.value
228            }
229            Ordering::Less => {
230                10i128.pow((other.scale - self.scale) as u32) * self.value == other.value
231            }
232            Ordering::Equal => self.value == other.value,
233        }
234    }
235}
236
#[cfg(feature = "rust_decimal")]
mod decimal {
    use super::{Decimal, Numeric};
    use crate::ColumnData;

    // Decoding: rebuild a `Decimal` from the unscaled integer and its scale.
    #[cfg(feature = "tds73")]
    from_sql!(Decimal: ColumnData::Numeric(ref num) => num.map(|numeric| {
        let unscaled = numeric.value();
        let digits_after_point = numeric.scale() as u32;
        Decimal::from_i128_with_scale(unscaled, digits_after_point)
    }));

    // Encoding: reassemble the mantissa from the `hi`/`mid`/`lo` parts of the
    // unpacked `Decimal`, apply its sign and keep its scale.
    #[cfg(feature = "tds73")]
    to_sql!(self_,
            Decimal: (ColumnData::Numeric, {
                let parts = self_.unpack();

                let magnitude = ((parts.hi as u128) << 64)
                    | ((parts.mid as u128) << 32)
                    | (parts.lo as u128);
                let magnitude = magnitude as i128;

                let value = if self_.is_sign_negative() {
                    -magnitude
                } else {
                    magnitude
                };

                Numeric::new_with_scale(value, self_.scale() as u8)
            });
    );
}
267
#[cfg(feature = "bigdecimal")]
mod bigdecimal_ {
    use super::{BigDecimal, BigInt, Numeric};
    use crate::ColumnData;
    use num_traits::ToPrimitive;
    use std::convert::TryFrom;

    #[cfg(feature = "tds73")]
    from_sql!(BigDecimal: ColumnData::Numeric(ref num) => num.map(|num| {
        let int = BigInt::from(num.value());

        BigDecimal::new(int, num.scale() as i64)
    }));

    /// Converts a `BigDecimal` into a `Numeric` for sending to the server.
    /// Shared by the `ToSql` and `IntoSql` impls below.
    ///
    /// SQL Server cannot store negative scales, so a value with a negative
    /// exponent is first rescaled to scale 0. For example, `Decimal(9, -3)` is
    /// stored as `Decimal(9000, 0)`.
    ///
    /// # Panics
    ///
    /// Panics if the unscaled integer does not fit in an `i128`, if the scale
    /// does not fit in a `u8`, or if `Numeric::new_with_scale` rejects the
    /// scale.
    #[cfg(feature = "tds73")]
    fn bigdecimal_to_numeric(dec: &BigDecimal) -> Numeric {
        let (int, exp) = dec.as_bigint_and_exponent();

        let (int, exp) = if exp < 0 {
            dec.with_scale(0).into_bigint_and_exponent()
        } else {
            (int, exp)
        };

        let value = int
            .to_i128()
            .expect("Given BigDecimal overflowing the maximum accepted value.");

        let scale = u8::try_from(std::cmp::max(exp, 0))
            .expect("Given BigDecimal exponent overflowing the maximum accepted scale (255).");

        Numeric::new_with_scale(value, scale)
    }

    #[cfg(feature = "tds73")]
    to_sql!(self_,
            BigDecimal: (ColumnData::Numeric, bigdecimal_to_numeric(self_));
    );

    #[cfg(feature = "tds73")]
    into_sql!(self_,
            BigDecimal: (ColumnData::Numeric, bigdecimal_to_numeric(&self_));
    );
}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn numeric_eq() {
        let two_places = Numeric {
            value: 100501,
            scale: 2,
        };
        let three_places = Numeric {
            value: 1005010,
            scale: 3,
        };
        assert_eq!(two_places, three_places);

        let one_place = Numeric {
            value: 10050,
            scale: 1,
        };
        assert_ne!(two_places, one_place);
    }

    #[test]
    fn numeric_to_f64() {
        let n = Numeric::new_with_scale(57705, 2);
        assert_eq!(577.05, f64::from(n));
    }

    #[test]
    fn numeric_to_int_dec_part() {
        let n = Numeric::new_with_scale(57705, 2);
        assert_eq!((577, 5), (n.int_part(), n.dec_part()));
    }

    #[test]
    fn calculates_precision_correctly() {
        assert_eq!(Numeric::new_with_scale(57705, 2).precision(), 5);
    }

    #[test]
    #[cfg(feature = "bigdecimal")]
    fn no_overflowing_pow() {
        use crate::{ColumnData, ToSql};
        use bigdecimal::FromPrimitive;

        // 1e20 has a negative exponent and must be rescaled to scale 0.
        let one_e20 = BigDecimal::new(BigInt::from_i8(1).unwrap(), -20);
        let expected = ColumnData::Numeric(Some(Numeric::new_with_scale(
            100_000_000_000_000_000_000i128,
            0,
        )));

        assert_eq!(expected, one_e20.to_sql());
    }
}
389        );
390    }
391}