Skip to main content

bnum/buint/
mod.rs

1use crate::errors::{self, option_expect};
2
3use crate::digit;
4use crate::doc;
5use crate::ExpType;
6// use core::mem::MaybeUninit;
7
8#[cfg(feature = "serde")]
9use ::{
10    serde::{Deserialize, Serialize},
11    serde_big_array::BigArray,
12};
13
14#[cfg(feature = "borsh")]
15use ::{
16    alloc::string::ToString,
17    borsh::{BorshDeserialize, BorshSchema, BorshSerialize},
18};
19
20use core::default::Default;
21
22use core::iter::{Iterator, Product, Sum};
23
24macro_rules! mod_impl {
25    ($BUint: ident, $BInt: ident, $Digit: ident) => {
        /// Unsigned integer type composed of
        #[doc = concat!("`", stringify!($Digit), "`")]
        /// digits, of arbitrary fixed size which must be known at compile time.
        ///
        /// Digits are stored in little endian (least significant digit first). This integer type aims to exactly replicate the behaviours of Rust's built-in unsigned integer types: `u8`, `u16`, `u32`, `u64`, `u128` and `usize`. The const generic parameter `N` is the number of
        #[doc = concat!("`", stringify!($Digit), "`")]
        /// digits that are stored.
        ///
        #[doc = doc::arithmetic_doc!($BUint)]

        // The derived `PartialEq`/`Eq`/`Hash` operate on the digit array, so two values are
        // equal exactly when every digit matches.
        #[derive(Clone, Copy, Hash, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
        #[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize, BorshSchema))]
        #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
        #[cfg_attr(feature = "valuable", derive(valuable::Valuable))]
        // `repr(transparent)`: the struct has exactly the layout of its single digit-array field.
        #[repr(transparent)]
        pub struct $BUint<const N: usize> {
            // `serde_big_array::BigArray` provides (de)serialization for arrays of arbitrary
            // const length `N`, which serde's own array impls do not cover.
            #[cfg_attr(feature = "serde", serde(with = "BigArray"))]
            // Little-endian digits: `digits[0]` is the least significant digit.
            pub(crate) digits: [$Digit; N],
        }
46
        // Marker impl: zeroizing a value resets it to `Default::default()`, which for this type
        // is `ZERO` (see the `Default` impl in this macro).
        #[cfg(feature = "zeroize")]
        impl<const N: usize> zeroize::DefaultIsZeroes for $BUint<N> {}
49
50        impl<const N: usize> $BUint<N> {
51            #[doc = doc::count_ones!(U 1024)]
52            #[must_use = doc::must_use_op!()]
53            #[inline]
54            pub const fn count_ones(self) -> ExpType {
55                let mut ones = 0;
56                let mut i = 0;
57                while i < N {
58                    ones += self.digits[i].count_ones() as ExpType;
59                    i += 1;
60                }
61                ones
62            }
63
64            #[doc = doc::count_zeros!(U 1024)]
65            #[must_use = doc::must_use_op!()]
66            #[inline]
67            pub const fn count_zeros(self) -> ExpType {
68                let mut zeros = 0;
69                let mut i = 0;
70                while i < N {
71                    zeros += self.digits[i].count_zeros() as ExpType;
72                    i += 1;
73                }
74                zeros
75            }
76
77            #[doc = doc::leading_zeros!(U 1024)]
78            #[must_use = doc::must_use_op!()]
79            #[inline]
80            pub const fn leading_zeros(self) -> ExpType {
81                let mut zeros = 0;
82                let mut i = N;
83                while i > 0 {
84                    i -= 1;
85                    let digit = self.digits[i];
86                    zeros += digit.leading_zeros() as ExpType;
87                    if digit != $Digit::MIN {
88                        break;
89                    }
90                }
91                zeros
92            }
93
94            #[doc = doc::trailing_zeros!(U 1024)]
95            #[must_use = doc::must_use_op!()]
96            #[inline]
97            pub const fn trailing_zeros(self) -> ExpType {
98                let mut zeros = 0;
99                let mut i = 0;
100                while i < N {
101                    let digit = self.digits[i];
102                    zeros += digit.trailing_zeros() as ExpType;
103                    if digit != $Digit::MIN {
104                        break;
105                    }
106                    i += 1;
107                }
108                zeros
109            }
110
111            #[doc = doc::leading_ones!(U 1024, MAX)]
112            #[must_use = doc::must_use_op!()]
113            #[inline]
114            pub const fn leading_ones(self) -> ExpType {
115                let mut ones = 0;
116                let mut i = N;
117                while i > 0 {
118                    i -= 1;
119                    let digit = self.digits[i];
120                    ones += digit.leading_ones() as ExpType;
121                    if digit != $Digit::MAX {
122                        break;
123                    }
124                }
125                ones
126            }
127
128            #[doc = doc::trailing_ones!(U 1024)]
129            #[must_use = doc::must_use_op!()]
130            #[inline]
131            pub const fn trailing_ones(self) -> ExpType {
132                let mut ones = 0;
133                let mut i = 0;
134                while i < N {
135                    let digit = self.digits[i];
136                    ones += digit.trailing_ones() as ExpType;
137                    if digit != $Digit::MAX {
138                        break;
139                    }
140                    i += 1;
141                }
142                ones
143            }
144
145            #[doc = doc::cast_signed!(U)]
146            #[must_use = doc::must_use_op!()]
147            #[inline]
148            pub const fn cast_signed(self) -> $BInt<N> {
149                $BInt::<N>::from_bits(self)
150            }
151
152            #[inline]
153            const unsafe fn rotate_digits_left(self, n: usize) -> Self {
154                let mut out = Self::ZERO;
155                let mut i = n;
156                while i < N {
157                    out.digits[i] = self.digits[i - n];
158                    i += 1;
159                }
160                let init_index = N - n;
161                let mut i = init_index;
162                while i < N {
163                    out.digits[i - init_index] = self.digits[i];
164                    i += 1;
165                }
166
167                out
168            }
169
170            #[inline]
171            const unsafe fn unchecked_rotate_left(self, rhs: ExpType) -> Self {
172                let digit_shift = (rhs >> digit::$Digit::BIT_SHIFT) as usize;
173                let bit_shift = rhs & digit::$Digit::BITS_MINUS_1;
174
175                let mut out = self.rotate_digits_left(digit_shift);
176
177                if bit_shift != 0 {
178                    let carry_shift = digit::$Digit::BITS - bit_shift;
179                    let mut carry = 0;
180
181                    let mut i = 0;
182                    while i < N {
183                        let current_digit = out.digits[i];
184                        out.digits[i] = (current_digit << bit_shift) | carry;
185                        carry = current_digit >> carry_shift;
186                        i += 1;
187                    }
188                    out.digits[0] |= carry;
189                }
190
191                out
192            }
193
194            const BITS_MINUS_1: ExpType = (Self::BITS - 1) as ExpType;
195
196            #[doc = doc::rotate_left!(U 256, "u")]
197            #[must_use = doc::must_use_op!()]
198            #[inline]
199            pub const fn rotate_left(self, n: ExpType) -> Self {
200                unsafe {
201                    self.unchecked_rotate_left(n & Self::BITS_MINUS_1)
202                }
203            }
204
205            #[doc = doc::rotate_right!(U 256, "u")]
206            #[must_use = doc::must_use_op!()]
207            #[inline]
208            pub const fn rotate_right(self, n: ExpType) -> Self {
209                let n = n & Self::BITS_MINUS_1;
210                unsafe {
211                    self.unchecked_rotate_left(Self::BITS as ExpType - n)
212                }
213            }
214
215            const N_MINUS_1: usize = N - 1;
216
217            #[doc = doc::swap_bytes!(U 256, "u")]
218            #[must_use = doc::must_use_op!()]
219            #[inline]
220            pub const fn swap_bytes(self) -> Self {
221                let mut uint = Self::ZERO;
222                let mut i = 0;
223                while i < N {
224                    uint.digits[i] = self.digits[Self::N_MINUS_1 - i].swap_bytes();
225                    i += 1;
226                }
227                uint
228            }
229
230            #[doc = doc::reverse_bits!(U 256, "u")]
231            #[must_use = doc::must_use_op!()]
232            #[inline]
233            pub const fn reverse_bits(self) -> Self {
234                let mut uint = Self::ZERO;
235                let mut i = 0;
236                while i < N {
237                    uint.digits[i] = self.digits[Self::N_MINUS_1 - i].reverse_bits();
238                    i += 1;
239                }
240                uint
241            }
242
243            #[doc = doc::pow!(U 256)]
244            #[must_use = doc::must_use_op!()]
245            #[inline]
246            pub const fn pow(self, exp: ExpType) -> Self {
247                #[cfg(debug_assertions)]
248                return self.strict_pow(exp);
249
250                #[cfg(not(debug_assertions))]
251                self.wrapping_pow(exp)
252            }
253
254            #[doc = doc::div_euclid!(U)]
255            #[must_use = doc::must_use_op!()]
256            #[inline]
257            pub const fn div_euclid(self, rhs: Self) -> Self {
258                self.wrapping_div_euclid(rhs)
259            }
260
261
262            #[doc = doc::rem_euclid!(U)]
263            #[must_use = doc::must_use_op!()]
264            #[inline]
265            pub const fn rem_euclid(self, rhs: Self) -> Self {
266                self.wrapping_rem_euclid(rhs)
267            }
268
269            #[doc = doc::doc_comment! {
270                U 256,
271                "Returns `true` if and only if `self == 2^k` for some integer `k`.",
272
273                "let n = " stringify!(U256) "::from(1u16 << 14);\n"
274                "assert!(n.is_power_of_two());\n"
275                "let m = " stringify!(U256) "::from(100u8);\n"
276                "assert!(!m.is_power_of_two());"
277            }]
278            #[must_use]
279            #[inline]
280            pub const fn is_power_of_two(self) -> bool {
281                let mut i = 0;
282                let mut ones = 0;
283                while i < N {
284                    ones += (&self.digits)[i].count_ones();
285                    if ones > 1 {
286                        return false;
287                    }
288                    i += 1;
289                }
290                ones == 1
291            }
292
293            #[doc = doc::next_power_of_two!(U 256, "0", "ZERO")]
294            #[must_use = doc::must_use_op!()]
295            #[inline]
296            pub const fn next_power_of_two(self) -> Self {
297                #[cfg(debug_assertions)]
298                return option_expect!(
299                    self.checked_next_power_of_two(),
300                    errors::err_msg!("attempt to calculate next power of two with overflow")
301                );
302                #[cfg(not(debug_assertions))]
303                self.wrapping_next_power_of_two()
304            }
305
306            #[doc = doc::midpoint!(U)]
307            #[must_use = doc::must_use_op!()]
308            #[inline]
309            pub const fn midpoint(self, rhs: Self) -> Self {
310                // see section 2.5: Average of Two Integers in Hacker's Delight
311                self.bitand(rhs).add(self.bitxor(rhs).shr(1))
312            }
313
314            #[doc = doc::ilog2!(U)]
315            #[must_use = doc::must_use_op!()]
316            #[inline]
317            pub const fn ilog2(self) -> ExpType {
318                option_expect!(
319                    self.checked_ilog2(),
320                    errors::err_msg!(errors::non_positive_log_message!())
321                )
322            }
323
324            #[doc = doc::ilog10!(U)]
325            #[must_use = doc::must_use_op!()]
326            #[inline]
327            pub const fn ilog10(self) -> ExpType {
328                option_expect!(
329                    self.checked_ilog10(),
330                    errors::err_msg!(errors::non_positive_log_message!())
331                )
332            }
333
334            #[doc = doc::ilog!(U)]
335            #[must_use = doc::must_use_op!()]
336            #[inline]
337            pub const fn ilog(self, base: Self) -> ExpType {
338                if base.le(&Self::ONE) {
339                    panic!("{}", errors::err_msg!(errors::invalid_log_base!()));
340                }
341                option_expect!(
342                    self.checked_ilog(base), errors::err_msg!(errors::non_positive_log_message!())
343                )
344            }
345
346            #[doc = doc::abs_diff!(U)]
347            #[must_use = doc::must_use_op!()]
348            #[inline]
349            pub const fn abs_diff(self, other: Self) -> Self {
350                if self.lt(&other) {
351                    other.wrapping_sub(self)
352                } else {
353                    self.wrapping_sub(other)
354                }
355            }
356
357            #[doc = doc::next_multiple_of!(U)]
358            #[must_use = doc::must_use_op!()]
359            #[inline]
360            pub const fn next_multiple_of(self, rhs: Self) -> Self {
361                let rem = self.wrapping_rem(rhs);
362                if rem.is_zero() {
363                    self
364                } else {
365                    self.add(rhs.sub(rem))
366                }
367            }
368
369            #[doc = doc::div_floor!(U)]
370            #[must_use = doc::must_use_op!()]
371            #[inline]
372            pub const fn div_floor(self, rhs: Self) -> Self {
373                self.wrapping_div(rhs)
374            }
375
376            #[doc = doc::div_ceil!(U)]
377            #[must_use = doc::must_use_op!()]
378            #[inline]
379            pub const fn div_ceil(self, rhs: Self) -> Self {
380                let (div, rem) = self.div_rem(rhs);
381                if rem.is_zero() {
382                    div
383                } else {
384                    div.add(Self::ONE)
385                }
386            }
387        }
388
389        impl<const N: usize> $BUint<N> {
390            #[inline]
391            pub(crate) const unsafe fn unchecked_shl_internal(self, rhs: ExpType) -> Self {
392                let mut out = $BUint::ZERO;
393                let digit_shift = (rhs >> digit::$Digit::BIT_SHIFT) as usize;
394                let bit_shift = rhs & digit::$Digit::BITS_MINUS_1;
395
396                // let num_copies = N.saturating_sub(digit_shift); // TODO: use unchecked_ methods from primitives when these are stablised and constified
397
398                if bit_shift != 0 {
399                    let carry_shift = digit::$Digit::BITS - bit_shift;
400                    let mut carry = 0;
401
402                    let mut i = digit_shift;
403                    while i < N {
404                        let current_digit = self.digits[i - digit_shift];
405                        out.digits[i] = (current_digit << bit_shift) | carry;
406                        carry = current_digit >> carry_shift;
407                        i += 1;
408                    }
409                } else {
410                    let mut i = digit_shift;
411                    while i < N { // we start i at digit_shift, not 0, since the compiler can elide bounds checks when i < N
412                        out.digits[i] = self.digits[i - digit_shift];
413                        i += 1;
414                    }
415                }
416
417                out
418            }
419
420            #[inline]
421            pub(crate) const unsafe fn unchecked_shr_pad_internal<const NEG: bool>(self, rhs: ExpType) -> Self {
422                let mut out = if NEG {
423                    $BUint::MAX
424                } else {
425                    $BUint::ZERO
426                };
427                let digit_shift = (rhs >> digit::$Digit::BIT_SHIFT) as usize;
428                let bit_shift = rhs & digit::$Digit::BITS_MINUS_1;
429
430                let num_copies = N.saturating_sub(digit_shift); // TODO: use unchecked_ methods from primitives when these are stablised and constified
431
432                if bit_shift != 0 {
433                    let carry_shift = digit::$Digit::BITS - bit_shift;
434                    let mut carry = 0;
435
436                    let mut i = digit_shift;
437                    while i < N { // we use an increment while loop because the compiler can elide the array bounds check, which results in big performance gains
438                        let index = N - 1 - i;
439                        let current_digit = self.digits[index + digit_shift];
440                        out.digits[index] = (current_digit >> bit_shift) | carry;
441                        carry = current_digit << carry_shift;
442                        i += 1;
443                    }
444
445                    if NEG {
446                        out.digits[num_copies - 1] |= $Digit::MAX << carry_shift;
447                    }
448                } else {
449                    let mut i = digit_shift;
450                    while i < N { // we start i at digit_shift, not 0, since the compiler can elide bounds checks when i < N
451                        out.digits[i - digit_shift] = self.digits[i];
452                        i += 1;
453                    }
454                }
455
456                out
457            }
458
459            pub(crate) const unsafe fn unchecked_shr_internal(u: $BUint<N>, rhs: ExpType) -> $BUint<N> {
460                Self::unchecked_shr_pad_internal::<false>(u, rhs)
461            }
462
463            #[doc = doc::bits!(U 256)]
464            #[must_use]
465            #[inline]
466            pub const fn bits(&self) -> ExpType {
467                Self::BITS as ExpType - self.leading_zeros()
468            }
469
470            #[doc = doc::bit!(U 256)]
471            #[must_use]
472            #[inline]
473            pub const fn bit(&self, index: ExpType) -> bool {
474                let digit = self.digits[index as usize >> digit::$Digit::BIT_SHIFT];
475                digit & (1 << (index & digit::$Digit::BITS_MINUS_1)) != 0
476            }
477
478            #[doc = doc::set_bit!(U 256)]
479            #[inline]
480            pub fn set_bit(&mut self, index: ExpType, value: bool) {
481                let digit = &mut self.digits[index as usize >> digit::$Digit::BIT_SHIFT];
482                let shift = index & digit::$Digit::BITS_MINUS_1;
483                if value {
484                    *digit |= (1 << shift);
485                } else {
486                    *digit &= !(1 << shift);
487                }
488            }
489
490            /// Returns an integer whose value is `2^power`. This is faster than using a shift left on `Self::ONE`.
491            ///
492            /// # Panics
493            ///
494            /// This function will panic if `power` is greater than or equal to `Self::BITS`.
495            ///
496            /// # Examples
497            ///
498            /// ```
499            /// use bnum::types::U256;
500            ///
501            /// let power = 11;
502            /// assert_eq!(U256::power_of_two(11), (1u128 << 11).into());
503            /// ```
504            #[must_use]
505            #[inline]
506            pub const fn power_of_two(power: ExpType) -> Self {
507                let mut out = Self::ZERO;
508                out.digits[power as usize >> digit::$Digit::BIT_SHIFT] = 1 << (power & (digit::$Digit::BITS - 1));
509                out
510            }
511
512            // #[inline(always)]
513            // pub(crate) const fn digit(&self, index: usize) -> $Digit {
514            //     self.digits[index]
515            // }
516
517            /// Returns the digits stored in `self` as an array. Digits are little endian (least significant digit first).
518            #[must_use]
519            #[inline(always)]
520            pub const fn digits(&self) -> &[$Digit; N] {
521                &self.digits
522            }
523
524            /// Returns the digits stored in `self` as a mutable array. Digits are little endian (least significant digit first).
525            #[must_use]
526            #[inline(always)]
527            pub fn digits_mut(&mut self) -> &mut [$Digit; N] {
528                &mut self.digits
529            }
530
531            /// Creates a new unsigned integer from the given array of digits. Digits are stored as little endian (least significant digit first).
532            #[must_use]
533            #[inline(always)]
534            pub const fn from_digits(digits: [$Digit; N]) -> Self {
535                Self { digits }
536            }
537
538            /// Creates a new unsigned integer from the given digit. The given digit is stored as the least significant digit.
539            #[must_use]
540            #[inline(always)]
541            pub const fn from_digit(digit: $Digit) -> Self {
542                let mut out = Self::ZERO;
543                out.digits[0] = digit;
544                out
545            }
546
547            #[doc = doc::is_zero!(U 256)]
548            #[must_use]
549            #[inline]
550            pub const fn is_zero(&self) -> bool {
551                let mut i = 0;
552                while i < N {
553                    if (&self.digits)[i] != 0 {
554                        return false;
555                    }
556                    i += 1;
557                }
558                true
559            }
560
561            #[doc = doc::is_one!(U 256)]
562            #[must_use]
563            #[inline]
564            pub const fn is_one(&self) -> bool {
565                if N == 0 || self.digits[0] != 1 {
566                    return false;
567                }
568                let mut i = 1;
569                while i < N {
570                    if (&self.digits)[i] != 0 {
571                        return false;
572                    }
573                    i += 1;
574                }
575                true
576            }
577
578            #[inline]
579            pub(crate) const fn last_digit_index(&self) -> usize {
580                let mut index = 0;
581                let mut i = 1;
582                while i < N {
583                    if (&self.digits)[i] != 0 {
584                        index = i;
585                    }
586                    i += 1;
587                }
588                index
589            }
590
591            #[allow(unused)]
592            #[inline]
593            fn square(self) -> Self {
594                // TODO: optimise this method, this will make exponentiation by squaring faster
595                self * self
596            }
597        }
598
599        impl<const N: usize> Default for $BUint<N> {
600            #[doc = doc::default!()]
601            #[inline]
602            fn default() -> Self {
603                Self::ZERO
604            }
605        }
606
607        impl<const N: usize> Product<Self> for $BUint<N> {
608            #[inline]
609            fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
610                iter.fold(Self::ONE, |a, b| a * b)
611            }
612        }
613
614        impl<'a, const N: usize> Product<&'a Self> for $BUint<N> {
615            #[inline]
616            fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
617                iter.fold(Self::ONE, |a, b| a * b)
618            }
619        }
620
621        impl<const N: usize> Sum<Self> for $BUint<N> {
622            #[inline]
623            fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
624                iter.fold(Self::ZERO, |a, b| a + b)
625            }
626        }
627
628        impl<'a, const N: usize> Sum<&'a Self> for $BUint<N> {
629            #[inline]
630            fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
631                iter.fold(Self::ZERO, |a, b| a + b)
632            }
633        }
634
635        #[cfg(any(test, feature = "quickcheck"))]
636        impl<const N: usize> quickcheck::Arbitrary for $BUint<N> {
637            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
638                let mut out = Self::ZERO;
639                for digit in out.digits.iter_mut() {
640                    *digit = <$Digit as quickcheck::Arbitrary>::arbitrary(g);
641                }
642                out
643            }
644        }
645    };
646}
647
#[cfg(test)]
crate::test::all_digit_tests! {
    use crate::test::{debug_skip, test_bignum, types::utest};

    crate::int::tests!(utest);

    test_bignum! {
        function: <utest>::next_power_of_two(a: utest),
        skip: debug_skip!(a.checked_next_power_of_two().is_none())
    }
    test_bignum! {
        function: <utest>::is_power_of_two(a: utest)
    }
    test_bignum! {
        function: <utest>::cast_signed(a: utest)
    }

    #[test]
    fn digits() {
        let a = UTEST::MAX;
        let digits = *a.digits();
        assert_eq!(a, UTEST::from_digits(digits));
    }

    #[test]
    fn bit() {
        let u = UTEST::from(0b001010100101010101u64);
        assert!(u.bit(0));
        assert!(!u.bit(1));
        // assert!(!u.bit(17));
        // assert!(!u.bit(16));
        assert!(u.bit(15));
    }

    #[test]
    fn is_zero() {
        assert!(UTEST::MIN.is_zero());
        assert!(!UTEST::MAX.is_zero());
        assert!(!UTEST::ONE.is_zero());
    }

    #[test]
    fn is_one() {
        assert!(UTEST::ONE.is_one());
        assert!(!UTEST::MAX.is_one());
        assert!(!UTEST::ZERO.is_one());
        let mut digits = *super::BUint::<2>::MAX.digits();
        digits[0] = 1;
        let b = super::BUint::<2>::from_digits(digits);
        assert!(!b.is_one());
    }

    #[test]
    fn bits() {
        let u = UTEST::from(0b1010100101010101u128);
        assert_eq!(u.bits(), 16);

        let u = UTEST::power_of_two(7);
        assert_eq!(u.bits(), 8);
    }

    #[test]
    fn default() {
        assert_eq!(UTEST::default(), utest::default().into());
    }

    #[test]
    fn sum() {
        let v = vec![UTEST::ZERO, UTEST::ONE, UTEST::TWO, UTEST::THREE, UTEST::FOUR];
        // Exercises `Sum<&Self>`.
        assert_eq!(UTEST::TEN, v.iter().sum());
        // Exercises `Sum<Self>`.
        assert_eq!(UTEST::TEN, v.into_iter().sum());
    }

    #[test]
    fn product() {
        // 2 * 3 * 4 = 24 differs from 2 + 3 + 4 = 9, so a sum cannot pass for a product.
        let expected = UTEST::from(24u8);
        let v = vec![UTEST::TWO, UTEST::THREE, UTEST::FOUR];
        // Exercises `Product<&Self>`.
        assert_eq!(expected, v.iter().product());
        // Exercises `Product<Self>`.
        assert_eq!(expected, v.into_iter().product());
    }
}
728
// Expands `mod_impl!` through the crate-level `macro_impl!` helper. NOTE(review): presumably
// this instantiates the unsigned type definitions above once per supported digit type —
// confirm against `crate::macro_impl`.
crate::macro_impl!(mod_impl);
730
731mod bigint_helpers;
732pub mod cast;
733mod checked;
734mod cmp;
735mod const_trait_fillers;
736mod consts;
737mod convert;
738mod div;
739mod endian;
740mod fmt;
741mod mask;
742mod mul;
743#[cfg(feature = "numtraits")]
744mod numtraits;
745mod ops;
746mod overflowing;
747mod radix;
748mod saturating;
749mod strict;
750mod unchecked;
751mod wrapping;