rust_decimal/
str.rs

1use crate::{
2    constants::{BYTES_TO_OVERFLOW_U64, MAX_SCALE, MAX_STR_BUFFER_SIZE, OVERFLOW_U96, WILL_OVERFLOW_U64},
3    error::{tail_error, Error},
4    ops::array::{add_by_internal_flattened, add_one_internal, div_by_u32, is_all_zero, mul_by_u32},
5    Decimal,
6};
7
8use arrayvec::{ArrayString, ArrayVec};
9
10use alloc::{string::String, vec::Vec};
11use core::fmt;
12
13// impl that doesn't allocate for serialization purposes.
14pub(crate) fn to_str_internal(
15    value: &Decimal,
16    append_sign: bool,
17    precision: Option<usize>,
18) -> (ArrayString<MAX_STR_BUFFER_SIZE>, Option<usize>) {
19    // Get the scale - where we need to put the decimal point
20    let scale = value.scale() as usize;
21
22    // Convert to a string and manipulate that (neg at front, inject decimal)
23    let mut chars = ArrayVec::<_, MAX_STR_BUFFER_SIZE>::new();
24    let mut working = value.mantissa_array3();
25    while !is_all_zero(&working) {
26        let remainder = div_by_u32(&mut working, 10u32);
27        chars.push(char::from(b'0' + remainder as u8));
28    }
29    while scale > chars.len() {
30        chars.push('0');
31    }
32
33    let (prec, additional) = match precision {
34        Some(prec) => {
35            let max: usize = MAX_SCALE.into();
36            if prec > max {
37                (max, Some(prec - max))
38            } else {
39                (prec, None)
40            }
41        }
42        None => (scale, None),
43    };
44
45    let len = chars.len();
46    let whole_len = len - scale;
47    let mut rep = ArrayString::new();
48    // Append the negative sign if necessary while also keeping track of the length of an "empty" string representation
49    let empty_len = if append_sign && value.is_sign_negative() {
50        rep.push('-');
51        1
52    } else {
53        0
54    };
55    for i in 0..whole_len + prec {
56        if i == len - scale {
57            if i == 0 {
58                rep.push('0');
59            }
60            rep.push('.');
61        }
62
63        if i >= len {
64            rep.push('0');
65        } else {
66            let c = chars[len - i - 1];
67            rep.push(c);
68        }
69    }
70
71    // corner case for when we truncated everything in a low fractional
72    if rep.len() == empty_len {
73        rep.push('0');
74    }
75
76    (rep, additional)
77}
78
79pub(crate) fn fmt_scientific_notation(
80    value: &Decimal,
81    exponent_symbol: &str,
82    f: &mut fmt::Formatter<'_>,
83) -> fmt::Result {
84    #[cfg(not(feature = "std"))]
85    use alloc::string::ToString;
86
87    // Get the scale - this is the e value. With multiples of 10 this may get bigger.
88    let mut exponent = -(value.scale() as isize);
89
90    // Convert the integral to a string
91    let mut chars = Vec::new();
92    let mut working = value.mantissa_array3();
93    while !is_all_zero(&working) {
94        let remainder = div_by_u32(&mut working, 10u32);
95        chars.push(char::from(b'0' + remainder as u8));
96    }
97
98    // First of all, apply scientific notation rules. That is:
99    //  1. If non-zero digit comes first, move decimal point left so that e is a positive integer
100    //  2. If decimal point comes first, move decimal point right until after the first non-zero digit
101    // Since decimal notation naturally lends itself this way, we just need to inject the decimal
102    // point in the right place and adjust the exponent accordingly.
103
104    let len = chars.len();
105    let mut rep;
106    // We either are operating with a precision specified, or on defaults. Defaults will perform "smart"
107    // reduction of precision.
108    if let Some(precision) = f.precision() {
109        if len > 1 {
110            // If we're zero precision AND it's trailing zeros then strip them
111            if precision == 0 && chars.iter().take(len - 1).all(|c| *c == '0') {
112                rep = chars.iter().skip(len - 1).collect::<String>();
113            } else {
114                // We may still be zero precision, however we aren't trailing zeros
115                if precision > 0 {
116                    chars.insert(len - 1, '.');
117                }
118                rep = chars
119                    .iter()
120                    .rev()
121                    // Add on extra zeros according to the precision. At least one, since we added a decimal place.
122                    .chain(core::iter::repeat(&'0'))
123                    .take(if precision == 0 { 1 } else { 2 + precision })
124                    .collect::<String>();
125            }
126            exponent += (len - 1) as isize;
127        } else if precision > 0 {
128            // We have precision that we want to add
129            chars.push('.');
130            rep = chars
131                .iter()
132                .chain(core::iter::repeat(&'0'))
133                .take(2 + precision)
134                .collect::<String>();
135        } else {
136            rep = chars.iter().collect::<String>();
137        }
138    } else if len > 1 {
139        // If the number is just trailing zeros then we treat it like 0 precision
140        if chars.iter().take(len - 1).all(|c| *c == '0') {
141            rep = chars.iter().skip(len - 1).collect::<String>();
142        } else {
143            // Otherwise, we need to insert a decimal place and make it a scientific number
144            chars.insert(len - 1, '.');
145            rep = chars.iter().rev().collect::<String>();
146        }
147        exponent += (len - 1) as isize;
148    } else {
149        rep = chars.iter().collect::<String>();
150    }
151
152    rep.push_str(exponent_symbol);
153    rep.push_str(&exponent.to_string());
154    f.pad_integral(value.is_sign_positive(), "", &rep)
155}
156
157// dedicated implementation for the most common case.
158#[inline]
159pub(crate) fn parse_str_radix_10(str: &str) -> Result<Decimal, Error> {
160    let bytes = str.as_bytes();
161    if bytes.len() < BYTES_TO_OVERFLOW_U64 {
162        parse_str_radix_10_dispatch::<false, true>(bytes)
163    } else {
164        parse_str_radix_10_dispatch::<true, true>(bytes)
165    }
166}
167
168#[inline]
169pub(crate) fn parse_str_radix_10_exact(str: &str) -> Result<Decimal, Error> {
170    let bytes = str.as_bytes();
171    if bytes.len() < BYTES_TO_OVERFLOW_U64 {
172        parse_str_radix_10_dispatch::<false, false>(bytes)
173    } else {
174        parse_str_radix_10_dispatch::<true, false>(bytes)
175    }
176}
177
178#[inline]
179fn parse_str_radix_10_dispatch<const BIG: bool, const ROUND: bool>(bytes: &[u8]) -> Result<Decimal, Error> {
180    match bytes {
181        [b, rest @ ..] => byte_dispatch_u64::<false, false, false, BIG, true, ROUND>(rest, 0, 0, *b),
182        [] => tail_error("Invalid decimal: empty"),
183    }
184}
185
186#[inline]
187fn overflow_64(val: u64) -> bool {
188    val >= WILL_OVERFLOW_U64
189}
190
191#[inline]
192pub fn overflow_128(val: u128) -> bool {
193    val >= OVERFLOW_U96
194}
195
196/// Dispatch the next byte:
197///
198/// * POINT - a decimal point has been seen
199/// * NEG - we've encountered a `-` and the number is negative
200/// * HAS - a digit has been encountered (when HAS is false it's invalid)
201/// * BIG - a number that uses 96 bits instead of only 64 bits
202/// * FIRST - true if it is the first byte in the string
203#[inline]
204fn dispatch_next<const POINT: bool, const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>(
205    bytes: &[u8],
206    data64: u64,
207    scale: u8,
208) -> Result<Decimal, Error> {
209    if let Some((next, bytes)) = bytes.split_first() {
210        byte_dispatch_u64::<POINT, NEG, HAS, BIG, false, ROUND>(bytes, data64, scale, *next)
211    } else {
212        handle_data::<NEG, HAS>(data64 as u128, scale)
213    }
214}
215
216/// Dispatch the next non-digit byte:
217///
218/// * POINT - a decimal point has been seen
219/// * NEG - we've encountered a `-` and the number is negative
220/// * HAS - a digit has been encountered (when HAS is false it's invalid)
221/// * BIG - a number that uses 96 bits instead of only 64 bits
222/// * FIRST - true if it is the first byte in the string
223/// * ROUND - attempt to round underflow
224#[inline(never)]
225fn non_digit_dispatch_u64<
226    const POINT: bool,
227    const NEG: bool,
228    const HAS: bool,
229    const BIG: bool,
230    const FIRST: bool,
231    const ROUND: bool,
232>(
233    bytes: &[u8],
234    data64: u64,
235    scale: u8,
236    b: u8,
237) -> Result<Decimal, Error> {
238    match b {
239        b'-' if FIRST && !HAS => dispatch_next::<false, true, false, BIG, ROUND>(bytes, data64, scale),
240        b'+' if FIRST && !HAS => dispatch_next::<false, false, false, BIG, ROUND>(bytes, data64, scale),
241        b'_' if HAS => handle_separator::<POINT, NEG, BIG, ROUND>(bytes, data64, scale),
242        b => tail_invalid_digit(b),
243    }
244}
245
246#[inline]
247fn byte_dispatch_u64<
248    const POINT: bool,
249    const NEG: bool,
250    const HAS: bool,
251    const BIG: bool,
252    const FIRST: bool,
253    const ROUND: bool,
254>(
255    bytes: &[u8],
256    data64: u64,
257    scale: u8,
258    b: u8,
259) -> Result<Decimal, Error> {
260    match b {
261        b'0'..=b'9' => handle_digit_64::<POINT, NEG, BIG, ROUND>(bytes, data64, scale, b - b'0'),
262        b'.' if !POINT => handle_point::<NEG, HAS, BIG, ROUND>(bytes, data64, scale),
263        b => non_digit_dispatch_u64::<POINT, NEG, HAS, BIG, FIRST, ROUND>(bytes, data64, scale, b),
264    }
265}
266
267#[inline(never)]
268fn handle_digit_64<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>(
269    bytes: &[u8],
270    data64: u64,
271    scale: u8,
272    digit: u8,
273) -> Result<Decimal, Error> {
274    // we have already validated that we cannot overflow
275    let data64 = data64 * 10 + digit as u64;
276    let scale = if POINT { scale + 1 } else { 0 };
277
278    if let Some((next, bytes)) = bytes.split_first() {
279        let next = *next;
280        if POINT && BIG && scale >= 28 {
281            if ROUND {
282                maybe_round(data64 as u128, next, scale, POINT, NEG)
283            } else {
284                Err(Error::Underflow)
285            }
286        } else if BIG && overflow_64(data64) {
287            handle_full_128::<POINT, NEG, ROUND>(data64 as u128, bytes, scale, next)
288        } else {
289            byte_dispatch_u64::<POINT, NEG, true, BIG, false, ROUND>(bytes, data64, scale, next)
290        }
291    } else {
292        let data: u128 = data64 as u128;
293
294        handle_data::<NEG, true>(data, scale)
295    }
296}
297
298#[inline(never)]
299fn handle_point<const NEG: bool, const HAS: bool, const BIG: bool, const ROUND: bool>(
300    bytes: &[u8],
301    data64: u64,
302    scale: u8,
303) -> Result<Decimal, Error> {
304    dispatch_next::<true, NEG, HAS, BIG, ROUND>(bytes, data64, scale)
305}
306
307#[inline(never)]
308fn handle_separator<const POINT: bool, const NEG: bool, const BIG: bool, const ROUND: bool>(
309    bytes: &[u8],
310    data64: u64,
311    scale: u8,
312) -> Result<Decimal, Error> {
313    dispatch_next::<POINT, NEG, true, BIG, ROUND>(bytes, data64, scale)
314}
315
316#[inline(never)]
317#[cold]
318fn tail_invalid_digit(digit: u8) -> Result<Decimal, Error> {
319    match digit {
320        b'.' => tail_error("Invalid decimal: two decimal points"),
321        b'_' => tail_error("Invalid decimal: must start lead with a number"),
322        _ => tail_error("Invalid decimal: unknown character"),
323    }
324}
325
/// Slow path for mantissas that no longer fit the 64-bit accumulator: continues accumulating
/// in a `u128`, processing `next_byte` first and then the rest of `bytes`.
///
/// * POINT - a decimal point has been seen
/// * NEG - the number is negative
/// * ROUND - round (rather than return `Error::Underflow`) when digits must be discarded
///
/// `data` is the mantissa so far and `scale` the number of fractional digits consumed. On every
/// entry `data` is below `OVERFLOW_U96`: callers pass either a former `u64` value or a `next`
/// that just passed `overflow_128`. So `data * 10` cannot overflow the `u128`.
///
/// Digits are discarded (rounded via `maybe_round`, or `Error::Underflow` when !ROUND) once the
/// mantissa would reach `OVERFLOW_U96` after the point, or once 28 fractional digits have been
/// consumed. Reaching `OVERFLOW_U96` before any point is an error.
#[inline(never)]
#[cold]
fn handle_full_128<const POINT: bool, const NEG: bool, const ROUND: bool>(
    mut data: u128,
    bytes: &[u8],
    scale: u8,
    next_byte: u8,
) -> Result<Decimal, Error> {
    let b = next_byte;
    match b {
        b'0'..=b'9' => {
            let digit = u32::from(b - b'0');

            // If the data is going to overflow then we should go into recovery mode
            let next = (data * 10) + digit as u128;
            if overflow_128(next) {
                // Integer digits can't be dropped without changing the magnitude.
                if !POINT {
                    return tail_error("Invalid decimal: overflow from too many digits");
                }

                // Fractional digit that doesn't fit: it becomes the rounding digit.
                if ROUND {
                    maybe_round(data, next_byte, scale, POINT, NEG)
                } else {
                    Err(Error::Underflow)
                }
            } else {
                data = next;
                // Only digits after the point increase the scale.
                let scale = scale + POINT as u8;
                if let Some((next, bytes)) = bytes.split_first() {
                    let next = *next;
                    // The maximum scale has been reached; anything further is discarded.
                    if POINT && scale >= 28 {
                        if ROUND {
                            // If it is an underscore at the rounding position we require slightly different handling to look ahead another digit
                            if next == b'_' {
                                // Skip consecutive underscores to find the next actual character
                                let mut remaining_bytes = bytes;
                                let mut next_char = None;
                                while let Some((n, rest)) = remaining_bytes.split_first() {
                                    if *n != b'_' {
                                        next_char = Some(*n);
                                        break;
                                    }
                                    remaining_bytes = rest;
                                }

                                if let Some(ch) = next_char {
                                    // Skip underscores and use the next character for rounding
                                    maybe_round(data, ch, scale, POINT, NEG)
                                } else {
                                    // Only underscores remain: nothing to round.
                                    handle_data::<NEG, true>(data, scale)
                                }
                            } else {
                                // Otherwise, we round as usual
                                maybe_round(data, next, scale, POINT, NEG)
                            }
                        } else {
                            Err(Error::Underflow)
                        }
                    } else {
                        handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, next)
                    }
                } else {
                    handle_data::<NEG, true>(data, scale)
                }
            }
        }
        b'.' if !POINT => {
            // This call won't tail?
            // Switch to POINT = true; the scale is unchanged until the next digit.
            if let Some((next, bytes)) = bytes.split_first() {
                handle_full_128::<true, NEG, ROUND>(data, bytes, scale, *next)
            } else {
                handle_data::<NEG, true>(data, scale)
            }
        }
        b'_' => {
            // Separator: skip it without touching the mantissa or scale.
            if let Some((next, bytes)) = bytes.split_first() {
                handle_full_128::<POINT, NEG, ROUND>(data, bytes, scale, *next)
            } else {
                handle_data::<NEG, true>(data, scale)
            }
        }
        b => tail_invalid_digit(b),
    }
}
410
411#[inline(never)]
412#[cold]
413fn maybe_round(mut data: u128, next_byte: u8, mut scale: u8, point: bool, negative: bool) -> Result<Decimal, Error> {
414    let digit = match next_byte {
415        b'0'..=b'9' => u32::from(next_byte - b'0'),
416        b'_' => 0, // This is perhaps an error case, but keep this here for compatibility
417        b'.' if !point => 0,
418        b => return tail_invalid_digit(b),
419    };
420
421    // Round at midpoint
422    if digit >= 5 {
423        data += 1;
424
425        // If the mantissa is now overflowing, round to the next
426        // next least significant digit and discard precision
427        if overflow_128(data) {
428            if scale == 0 {
429                return tail_error("Invalid decimal: overflow from mantissa after rounding");
430            }
431            data += 4;
432            data /= 10;
433            scale -= 1;
434        }
435    }
436
437    if negative {
438        handle_data::<true, true>(data, scale)
439    } else {
440        handle_data::<false, true>(data, scale)
441    }
442}
443
444#[inline(never)]
445fn tail_no_has() -> Result<Decimal, Error> {
446    tail_error("Invalid decimal: no digits found")
447}
448
449#[inline]
450fn handle_data<const NEG: bool, const HAS: bool>(data: u128, scale: u8) -> Result<Decimal, Error> {
451    debug_assert_eq!(data >> 96, 0);
452    if !HAS {
453        tail_no_has()
454    } else {
455        Ok(Decimal::from_parts(
456            data as u32,
457            (data >> 32) as u32,
458            (data >> 64) as u32,
459            NEG,
460            scale as u32,
461        ))
462    }
463}
464
465pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result<Decimal, Error> {
466    if str.is_empty() {
467        return Err(Error::from("Invalid decimal: empty"));
468    }
469    if radix < 2 {
470        return Err(Error::from("Unsupported radix < 2"));
471    }
472    if radix > 36 {
473        // As per trait documentation
474        return Err(Error::from("Unsupported radix > 36"));
475    }
476
477    let mut offset = 0;
478    let mut len = str.len();
479    let bytes = str.as_bytes();
480    let mut negative = false; // assume positive
481
482    // handle the sign
483    if bytes[offset] == b'-' {
484        negative = true; // leading minus means negative
485        offset += 1;
486        len -= 1;
487    } else if bytes[offset] == b'+' {
488        // leading + allowed
489        offset += 1;
490        len -= 1;
491    }
492
493    // should now be at numeric part of the significand
494    let mut digits_before_dot: i32 = -1; // digits before '.', -1 if no '.'
495    let mut coeff = ArrayVec::<_, 96>::new(); // integer significand array
496
497    // Supporting different radix
498    let (max_n, max_alpha_lower, max_alpha_upper) = if radix <= 10 {
499        (b'0' + (radix - 1) as u8, 0, 0)
500    } else {
501        let adj = (radix - 11) as u8;
502        (b'9', adj + b'a', adj + b'A')
503    };
504
505    // Estimate the max precision. All in all, it needs to fit into 96 bits.
506    // Rather than try to estimate, I've included the constants directly in here. We could,
507    // perhaps, replace this with a formula if it's faster - though it does appear to be log2.
508    let estimated_max_precision = match radix {
509        2 => 96,
510        3 => 61,
511        4 => 48,
512        5 => 42,
513        6 => 38,
514        7 => 35,
515        8 => 32,
516        9 => 31,
517        10 => 28,
518        11 => 28,
519        12 => 27,
520        13 => 26,
521        14 => 26,
522        15 => 25,
523        16 => 24,
524        17 => 24,
525        18 => 24,
526        19 => 23,
527        20 => 23,
528        21 => 22,
529        22 => 22,
530        23 => 22,
531        24 => 21,
532        25 => 21,
533        26 => 21,
534        27 => 21,
535        28 => 20,
536        29 => 20,
537        30 => 20,
538        31 => 20,
539        32 => 20,
540        33 => 20,
541        34 => 19,
542        35 => 19,
543        36 => 19,
544        _ => return Err(Error::from("Unsupported radix")),
545    };
546
547    let mut maybe_round = false;
548    while len > 0 {
549        let b = bytes[offset];
550        match b {
551            b'0'..=b'9' => {
552                if b > max_n {
553                    return Err(Error::from("Invalid decimal: invalid character"));
554                }
555                coeff.push(u32::from(b - b'0'));
556                offset += 1;
557                len -= 1;
558
559                // If the coefficient is longer than the max, exit early
560                if coeff.len() as u32 > estimated_max_precision {
561                    maybe_round = true;
562                    break;
563                }
564            }
565            b'a'..=b'z' => {
566                if b > max_alpha_lower {
567                    return Err(Error::from("Invalid decimal: invalid character"));
568                }
569                coeff.push(u32::from(b - b'a') + 10);
570                offset += 1;
571                len -= 1;
572
573                if coeff.len() as u32 > estimated_max_precision {
574                    maybe_round = true;
575                    break;
576                }
577            }
578            b'A'..=b'Z' => {
579                if b > max_alpha_upper {
580                    return Err(Error::from("Invalid decimal: invalid character"));
581                }
582                coeff.push(u32::from(b - b'A') + 10);
583                offset += 1;
584                len -= 1;
585
586                if coeff.len() as u32 > estimated_max_precision {
587                    maybe_round = true;
588                    break;
589                }
590            }
591            b'.' => {
592                if digits_before_dot >= 0 {
593                    return Err(Error::from("Invalid decimal: two decimal points"));
594                }
595                digits_before_dot = coeff.len() as i32;
596                offset += 1;
597                len -= 1;
598            }
599            b'_' => {
600                // Must start with a number...
601                if coeff.is_empty() {
602                    return Err(Error::from("Invalid decimal: must start lead with a number"));
603                }
604                offset += 1;
605                len -= 1;
606            }
607            _ => return Err(Error::from("Invalid decimal: unknown character")),
608        }
609    }
610
611    // If we exited before the end of the string then do some rounding if necessary
612    if maybe_round && offset < bytes.len() {
613        let next_byte = bytes[offset];
614        let digit = match next_byte {
615            b'0'..=b'9' => {
616                if next_byte > max_n {
617                    return Err(Error::from("Invalid decimal: invalid character"));
618                }
619                u32::from(next_byte - b'0')
620            }
621            b'a'..=b'z' => {
622                if next_byte > max_alpha_lower {
623                    return Err(Error::from("Invalid decimal: invalid character"));
624                }
625                u32::from(next_byte - b'a') + 10
626            }
627            b'A'..=b'Z' => {
628                if next_byte > max_alpha_upper {
629                    return Err(Error::from("Invalid decimal: invalid character"));
630                }
631                u32::from(next_byte - b'A') + 10
632            }
633            b'_' => 0,
634            b'.' => {
635                // Still an error if we have a second dp
636                if digits_before_dot >= 0 {
637                    return Err(Error::from("Invalid decimal: two decimal points"));
638                }
639                0
640            }
641            _ => return Err(Error::from("Invalid decimal: unknown character")),
642        };
643
644        // Round at midpoint
645        let midpoint = if radix & 0x1 == 1 { radix / 2 } else { (radix + 1) / 2 };
646        if digit >= midpoint {
647            let mut index = coeff.len() - 1;
648            loop {
649                let new_digit = coeff[index] + 1;
650                if new_digit <= 9 {
651                    coeff[index] = new_digit;
652                    break;
653                } else {
654                    coeff[index] = 0;
655                    if index == 0 {
656                        coeff.insert(0, 1u32);
657                        digits_before_dot += 1;
658                        coeff.pop();
659                        break;
660                    }
661                }
662                index -= 1;
663            }
664        }
665    }
666
667    // here when no characters left
668    if coeff.is_empty() {
669        return Err(Error::from("Invalid decimal: no digits found"));
670    }
671
672    let mut scale = if digits_before_dot >= 0 {
673        // we had a decimal place so set the scale
674        (coeff.len() as u32) - (digits_before_dot as u32)
675    } else {
676        0
677    };
678
679    // Parse this using specified radix
680    let mut data = [0u32, 0u32, 0u32];
681    let mut tmp = [0u32, 0u32, 0u32];
682    let len = coeff.len();
683    for (i, digit) in coeff.iter().enumerate() {
684        // If the data is going to overflow then we should go into recovery mode
685        tmp[0] = data[0];
686        tmp[1] = data[1];
687        tmp[2] = data[2];
688        let overflow = mul_by_u32(&mut tmp, radix);
689        if overflow > 0 {
690            // This means that we have more data to process, that we're not sure what to do with.
691            // This may or may not be an issue - depending on whether we're past a decimal point
692            // or not.
693            if (i as i32) < digits_before_dot && i + 1 < len {
694                return Err(Error::from("Invalid decimal: overflow from too many digits"));
695            }
696
697            if *digit >= 5 {
698                let carry = add_one_internal(&mut data);
699                if carry > 0 {
700                    // Highly unlikely scenario which is more indicative of a bug
701                    return Err(Error::from("Invalid decimal: overflow when rounding"));
702                }
703            }
704            // We're also one less digit so reduce the scale
705            let diff = (len - i) as u32;
706            if diff > scale {
707                return Err(Error::from("Invalid decimal: overflow from scale mismatch"));
708            }
709            scale -= diff;
710            break;
711        } else {
712            data[0] = tmp[0];
713            data[1] = tmp[1];
714            data[2] = tmp[2];
715            let carry = add_by_internal_flattened(&mut data, *digit);
716            if carry > 0 {
717                // Highly unlikely scenario which is more indicative of a bug
718                return Err(Error::from("Invalid decimal: overflow from carry"));
719            }
720        }
721    }
722
723    Ok(Decimal::from_parts(data[0], data[1], data[2], negative, scale))
724}
725
726#[cfg(test)]
727mod test {
728    use super::*;
729    use crate::Decimal;
730    use arrayvec::ArrayString;
731    use core::{fmt::Write, str::FromStr};
732
733    #[test]
734    fn display_does_not_overflow_max_capacity() {
735        let num = Decimal::from_str("1.2").unwrap();
736        let mut buffer = ArrayString::<64>::new();
737        buffer.write_fmt(format_args!("{num:.31}")).unwrap();
738        assert_eq!("1.2000000000000000000000000000000", buffer.as_str());
739    }
740
741    #[test]
742    fn from_str_rounding_0() {
743        assert_eq!(
744            parse_str_radix_10("1.234").unwrap().unpack(),
745            Decimal::new(1234, 3).unpack()
746        );
747    }
748
749    #[test]
750    fn from_str_rounding_1() {
751        assert_eq!(
752            parse_str_radix_10("11111_11111_11111.11111_11111_11111")
753                .unwrap()
754                .unpack(),
755            Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_111, 14).unpack()
756        );
757    }
758
759    #[test]
760    fn from_str_rounding_2() {
761        assert_eq!(
762            parse_str_radix_10("11111_11111_11111.11111_11111_11115")
763                .unwrap()
764                .unpack(),
765            Decimal::from_i128_with_scale(11_111_111_111_111_111_111_111_111_112, 14).unpack()
766        );
767    }
768
769    #[test]
770    fn from_str_rounding_3() {
771        assert_eq!(
772            parse_str_radix_10("11111_11111_11111.11111_11111_11195")
773                .unwrap()
774                .unpack(),
775            Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_1120, 14).unpack() // was Decimal::from_i128_with_scale(1_111_111_111_111_111_111_111_111_112, 13)
776        );
777    }
778
779    #[test]
780    fn from_str_rounding_4() {
781        assert_eq!(
782            parse_str_radix_10("99999_99999_99999.99999_99999_99995")
783                .unwrap()
784                .unpack(),
785            Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 13).unpack() // was Decimal::from_i128_with_scale(1_000_000_000_000_000_000_000_000_000, 12)
786        );
787    }
788
789    #[test]
790    fn from_str_no_rounding_0() {
791        assert_eq!(
792            parse_str_radix_10_exact("1.234").unwrap().unpack(),
793            Decimal::new(1234, 3).unpack()
794        );
795    }
796
797    #[test]
798    fn from_str_no_rounding_1() {
799        assert_eq!(
800            parse_str_radix_10_exact("11111_11111_11111.11111_11111_11111"),
801            Err(Error::Underflow)
802        );
803    }
804
805    #[test]
806    fn from_str_no_rounding_2() {
807        assert_eq!(
808            parse_str_radix_10_exact("11111_11111_11111.11111_11111_11115"),
809            Err(Error::Underflow)
810        );
811    }
812
813    #[test]
814    fn from_str_no_rounding_3() {
815        assert_eq!(
816            parse_str_radix_10_exact("11111_11111_11111.11111_11111_11195"),
817            Err(Error::Underflow)
818        );
819    }
820
821    #[test]
822    fn from_str_no_rounding_4() {
823        assert_eq!(
824            parse_str_radix_10_exact("99999_99999_99999.99999_99999_99995"),
825            Err(Error::Underflow)
826        );
827    }
828
829    #[test]
830    fn from_str_many_pointless_chars() {
831        assert_eq!(
832            parse_str_radix_10("00________________________________________________________________001.1")
833                .unwrap()
834                .unpack(),
835            Decimal::from_i128_with_scale(11, 1).unpack()
836        );
837    }
838
839    #[test]
840    fn from_str_leading_0s_1() {
841        assert_eq!(
842            parse_str_radix_10("00001.1").unwrap().unpack(),
843            Decimal::from_i128_with_scale(11, 1).unpack()
844        );
845    }
846
847    #[test]
848    fn from_str_leading_0s_2() {
849        assert_eq!(
850            parse_str_radix_10("00000_00000_00000_00000_00001.00001")
851                .unwrap()
852                .unpack(),
853            Decimal::from_i128_with_scale(100001, 5).unpack()
854        );
855    }
856
857    #[test]
858    fn from_str_leading_0s_3() {
859        assert_eq!(
860            parse_str_radix_10("0.00000_00000_00000_00000_00000_00100")
861                .unwrap()
862                .unpack(),
863            Decimal::from_i128_with_scale(1, 28).unpack()
864        );
865    }
866
867    #[test]
868    fn from_str_trailing_0s_1() {
869        assert_eq!(
870            parse_str_radix_10("0.00001_00000_00000").unwrap().unpack(),
871            Decimal::from_i128_with_scale(10_000_000_000, 15).unpack()
872        );
873    }
874
875    #[test]
876    fn from_str_trailing_0s_2() {
877        assert_eq!(
878            parse_str_radix_10("0.00001_00000_00000_00000_00000_00000")
879                .unwrap()
880                .unpack(),
881            Decimal::from_i128_with_scale(100_000_000_000_000_000_000_000, 28).unpack()
882        );
883    }
884
885    #[test]
886    fn from_str_overflow_1() {
887        assert_eq!(
888            parse_str_radix_10("99999_99999_99999_99999_99999_99999.99999"),
889            // The original implementation returned
890            //              Ok(10000_00000_00000_00000_00000_0000)
891            // Which is a bug!
892            Err(Error::from("Invalid decimal: overflow from too many digits"))
893        );
894    }
895
896    #[test]
897    fn from_str_overflow_2() {
898        assert!(
899            parse_str_radix_10("99999_99999_99999_99999_99999_11111.11111").is_err(),
900            // The original implementation is 'overflow from scale mismatch'
901            // but we got rid of that now
902        );
903    }
904
905    #[test]
906    fn from_str_overflow_3() {
907        assert!(
908            parse_str_radix_10("99999_99999_99999_99999_99999_99994").is_err() // We could not get into 'overflow when rounding' or 'overflow from carry'
909                                                                               // in the original implementation because the rounding logic before prevented it
910        );
911    }
912
913    #[test]
914    fn from_str_overflow_4() {
915        assert_eq!(
916            // This does not overflow, moving the decimal point 1 more step would result in
917            // 'overflow from too many digits'
918            parse_str_radix_10("99999_99999_99999_99999_99999_999.99")
919                .unwrap()
920                .unpack(),
921            Decimal::from_i128_with_scale(10_000_000_000_000_000_000_000_000_000, 0).unpack()
922        );
923    }
924
925    #[test]
926    fn from_str_mantissa_overflow_1() {
927        // reminder:
928        assert_eq!(OVERFLOW_U96, 79_228_162_514_264_337_593_543_950_336);
929        assert_eq!(
930            parse_str_radix_10("79_228_162_514_264_337_593_543_950_33.56")
931                .unwrap()
932                .unpack(),
933            Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 0).unpack()
934        );
935        // This is a mantissa of OVERFLOW_U96 - 1 just before reaching the last digit.
936        // Previously, this would return Err("overflow from mantissa after rounding")
937        // instead of successfully rounding.
938    }
939
940    #[test]
941    fn from_str_mantissa_overflow_2() {
942        assert_eq!(
943            parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.6"),
944            Err(Error::from("Invalid decimal: overflow from mantissa after rounding"))
945        );
946        // this case wants to round to 79_228_162_514_264_337_593_543_950_340.
947        // (79_228_162_514_264_337_593_543_950_336 is OVERFLOW_U96 and too large
948        // to fit in 96 bits) which is also too large for the mantissa so fails.
949    }
950
951    #[test]
952    fn from_str_mantissa_overflow_3() {
953        // this hits the other avoidable overflow case in maybe_round
954        assert_eq!(
955            parse_str_radix_10("7.92281625142643375935439503356").unwrap().unpack(),
956            Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack()
957        );
958    }
959
960    #[test]
961    fn from_str_mantissa_overflow_4() {
962        // Same test as above, however with underscores. This causes issues.
963        assert_eq!(
964            parse_str_radix_10("7.9_228_162_514_264_337_593_543_950_335_6")
965                .unwrap()
966                .unpack(),
967            Decimal::from_i128_with_scale(79_228_162_514_264_337_593_543_950_34, 27).unpack()
968        );
969    }
970
971    #[test]
972    fn invalid_input_1() {
973        assert_eq!(
974            parse_str_radix_10("1.0000000000000000000000000000.5"),
975            Err(Error::from("Invalid decimal: two decimal points"))
976        );
977    }
978
979    #[test]
980    fn invalid_input_2() {
981        assert_eq!(
982            parse_str_radix_10("1.0.5"),
983            Err(Error::from("Invalid decimal: two decimal points"))
984        );
985    }
986
987    #[test]
988    fn character_at_rounding_position() {
989        let tests = [
990            // digit is at the rounding position
991            (
992                "1.000_000_000_000_000_000_000_000_000_04",
993                Ok(Decimal::from_i128_with_scale(
994                    1_000_000_000_000_000_000_000_000_000_0,
995                    28,
996                )),
997            ),
998            (
999                "1.000_000_000_000_000_000_000_000_000_06",
1000                Ok(Decimal::from_i128_with_scale(
1001                    1_000_000_000_000_000_000_000_000_000_1,
1002                    28,
1003                )),
1004            ),
1005            // Decimal point is at the rounding position
1006            (
1007                "1_000_000_000_000_000_000_000_000_000_0.4",
1008                Ok(Decimal::from_i128_with_scale(
1009                    1_000_000_000_000_000_000_000_000_000_0,
1010                    0,
1011                )),
1012            ),
1013            (
1014                "1_000_000_000_000_000_000_000_000_000_0.6",
1015                Ok(Decimal::from_i128_with_scale(
1016                    1_000_000_000_000_000_000_000_000_000_1,
1017                    0,
1018                )),
1019            ),
1020            // Placeholder is at the rounding position
1021            (
1022                "1.000_000_000_000_000_000_000_000_000_0_4",
1023                Ok(Decimal::from_i128_with_scale(
1024                    1_000_000_000_000_000_000_000_000_000_0,
1025                    28,
1026                )),
1027            ),
1028            (
1029                "1.000_000_000_000_000_000_000_000_000_0_6",
1030                Ok(Decimal::from_i128_with_scale(
1031                    1_000_000_000_000_000_000_000_000_000_1,
1032                    28,
1033                )),
1034            ),
1035            // Multiple placeholders at rounding position
1036            (
1037                "1.000_000_000_000_000_000_000_000_000_0__4",
1038                Ok(Decimal::from_i128_with_scale(
1039                    1_000_000_000_000_000_000_000_000_000_0,
1040                    28,
1041                )),
1042            ),
1043            (
1044                "1.000_000_000_000_000_000_000_000_000_0__6",
1045                Ok(Decimal::from_i128_with_scale(
1046                    1_000_000_000_000_000_000_000_000_000_1,
1047                    28,
1048                )),
1049            ),
1050            (
1051                "1.234567890123456789012345678_9",
1052                Ok(Decimal::from_i128_with_scale(12345678901234567890123456789, 28)),
1053            ),
1054            (
1055                "0.234567890123456789012345678_9",
1056                Ok(Decimal::from_i128_with_scale(2345678901234567890123456789, 28)),
1057            ),
1058            (
1059                "0.1234567890123456789012345678_9",
1060                Ok(Decimal::from_i128_with_scale(1234567890123456789012345679, 28)),
1061            ),
1062            (
1063                "0.1234567890123456789012345678_4",
1064                Ok(Decimal::from_i128_with_scale(1234567890123456789012345678, 28)),
1065            ),
1066        ];
1067
1068        for (input, expected) in tests.iter() {
1069            assert_eq!(parse_str_radix_10(input), *expected, "Test input {}", input);
1070        }
1071    }
1072
1073    #[test]
1074    fn from_str_edge_cases_1() {
1075        assert_eq!(parse_str_radix_10(""), Err(Error::from("Invalid decimal: empty")));
1076    }
1077
1078    #[test]
1079    fn from_str_edge_cases_2() {
1080        assert_eq!(
1081            parse_str_radix_10("0.1."),
1082            Err(Error::from("Invalid decimal: two decimal points"))
1083        );
1084    }
1085
1086    #[test]
1087    fn from_str_edge_cases_3() {
1088        assert_eq!(
1089            parse_str_radix_10("_"),
1090            Err(Error::from("Invalid decimal: must start lead with a number"))
1091        );
1092    }
1093
1094    #[test]
1095    fn from_str_edge_cases_4() {
1096        assert_eq!(
1097            parse_str_radix_10("1?2"),
1098            Err(Error::from("Invalid decimal: unknown character"))
1099        );
1100    }
1101
1102    #[test]
1103    fn from_str_edge_cases_5() {
1104        assert_eq!(
1105            parse_str_radix_10("."),
1106            Err(Error::from("Invalid decimal: no digits found"))
1107        );
1108    }
1109
1110    #[test]
1111    fn from_str_edge_cases_6() {
1112        // Decimal::MAX + 0.99999
1113        assert_eq!(
1114            parse_str_radix_10("79_228_162_514_264_337_593_543_950_335.99999"),
1115            Err(Error::from("Invalid decimal: overflow from mantissa after rounding"))
1116        );
1117    }
1118}