1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
1718//! N-digit division
19//!
20//! Implementation heavily inspired by [uint]
21//!
22//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844
2324/// Unsigned, little-endian, n-digit division with remainder
25///
26/// # Panics
27///
28/// Panics if divisor is zero
29pub fn div_rem<const N: usize>(numerator: &[u64; N], divisor: &[u64; N]) -> ([u64; N], [u64; N]) {
30let numerator_bits = bits(numerator);
31let divisor_bits = bits(divisor);
32assert_ne!(divisor_bits, 0, "division by zero");
3334if numerator_bits < divisor_bits {
35return ([0; N], *numerator);
36 }
3738if divisor_bits <= 64 {
39return div_rem_small(numerator, divisor[0]);
40 }
4142let numerator_words = (numerator_bits + 63) / 64;
43let divisor_words = (divisor_bits + 63) / 64;
44let n = divisor_words;
45let m = numerator_words - divisor_words;
4647 div_rem_knuth(numerator, divisor, n, m)
48}
/// Return the least number of bits needed to represent the number
///
/// `arr` is interpreted as little-endian 64-bit digits; an empty or all-zero
/// slice needs 0 bits.
fn bits(arr: &[u64]) -> usize {
    arr.iter()
        // Locate the most significant non-zero digit
        .rposition(|&digit| digit != 0)
        .map_or(0, |top| {
            let used_in_top = 64 - arr[top].leading_zeros() as usize;
            used_in_top + 64 * top
        })
}
5960/// Division of numerator by a u64 divisor
61fn div_rem_small<const N: usize>(numerator: &[u64; N], divisor: u64) -> ([u64; N], [u64; N]) {
62let mut rem = 0u64;
63let mut numerator = *numerator;
64 numerator.iter_mut().rev().for_each(|d| {
65let (q, r) = div_rem_word(rem, *d, divisor);
66*d = q;
67 rem = r;
68 });
6970let mut rem_padded = [0; N];
71 rem_padded[0] = rem;
72 (numerator, rem_padded)
73}
7475/// Use Knuth Algorithm D to compute `numerator / divisor` returning the
76/// quotient and remainder
77///
78/// `n` is the number of non-zero 64-bit words in `divisor`
79/// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and
80/// therefore the number of words in the quotient
81///
82/// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html)
83fn div_rem_knuth<const N: usize>(
84 numerator: &[u64; N],
85 divisor: &[u64; N],
86 n: usize,
87 m: usize,
88) -> ([u64; N], [u64; N]) {
89assert!(n + m <= N);
9091// The algorithm works by incrementally generating guesses `q_hat`, for the next digit
92 // of the quotient, starting from the most significant digit.
93 //
94 // This relies on the property that for any `q_hat` where
95 //
96 // (q_hat << (j * 64)) * divisor <= numerator`
97 //
98 // We can set
99 //
100 // q += q_hat << (j * 64)
101 // numerator -= (q_hat << (j * 64)) * divisor
102 //
103 // And then iterate until `numerator < divisor`
104105 // We normalize the divisor so that the highest bit in the highest digit of the
106 // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from
107 // the correct value for q[j]
108let shift = divisor[n - 1].leading_zeros();
109// As the shift is computed based on leading zeros, don't need to perform full_shl
110let divisor = shl_word(divisor, shift);
111// numerator may have fewer leading zeros than divisor, so must add another digit
112let mut numerator = full_shl(numerator, shift);
113114// The two most significant digits of the divisor
115let b0 = divisor[n - 1];
116let b1 = divisor[n - 2];
117118let mut q = [0; N];
119120for j in (0..=m).rev() {
121let a0 = numerator[j + n];
122let a1 = numerator[j + n - 1];
123124let mut q_hat = if a0 < b0 {
125// The first estimate is [a1, a0] / b0, it may be too large by at most 2
126let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0);
127128// r_hat = [a1, a0] - q_hat * b0
129 //
130 // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0]
131 // which can only be less or equal to the current q_hat
132 //
133 // q_hat is too large if:
134 // [a2,a1,a0] < q_hat * [b1,b0]
135 // [a2,r_hat] < q_hat * b1
136let a2 = numerator[j + n - 2];
137loop {
138let r = u128::from(q_hat) * u128::from(b1);
139let (lo, hi) = (r as u64, (r >> 64) as u64);
140if (hi, lo) <= (r_hat, a2) {
141break;
142 }
143144 q_hat -= 1;
145let (new_r_hat, overflow) = r_hat.overflowing_add(b0);
146 r_hat = new_r_hat;
147148if overflow {
149break;
150 }
151 }
152 q_hat
153 } else {
154 u64::MAX
155 };
156157// q_hat is now either the correct quotient digit, or in rare cases 1 too large
158159 // Compute numerator -= (q_hat * divisor) << (j * 64)
160let q_hat_v = full_mul_u64(&divisor, q_hat);
161let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]);
162163// If underflow, q_hat was too large by 1
164if c {
165// Reduce q_hat by 1
166q_hat -= 1;
167168// Add back one multiple of divisor
169let c = add_assign(&mut numerator[j..], &divisor[..n]);
170 numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c));
171 }
172173// q_hat is the correct value for q[j]
174q[j] = q_hat;
175 }
176177// The remainder is what is left in numerator, with the initial normalization shl reversed
178let remainder = full_shr(&numerator, shift);
179 (q, remainder)
180}
/// Perform narrowing division of a u128 by a u64 divisor, returning the quotient and remainder
///
/// The dividend is `(hi << 64) | lo`; the result is `(quotient, remainder)`.
///
/// This method may trap or panic if hi >= divisor, i.e. the quotient would not fit
/// into a 64-bit integer
fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) {
    debug_assert!(hi < divisor);
    debug_assert_ne!(divisor, 0);

    // LLVM cannot prove that hi < divisor, and so will not emit the 128-by-64 bit
    // div instruction on its own; issue it by hand where it is available
    //
    // SAFETY: the asm only reads and writes the registers declared as operands and
    // touches neither memory nor the stack. Callers uphold `hi < divisor`, which is
    // the documented precondition for the quotient to fit in 64 bits.
    #[cfg(all(target_arch = "x86_64", not(miri)))]
    unsafe {
        // On entry rdx:rax holds the dividend; on exit rax is the quotient
        // and rdx the remainder
        let mut rax = lo;
        let mut rdx = hi;
        std::arch::asm!(
            "div {divisor}",
            divisor = in(reg) divisor,
            inout("rax") rax,
            inout("rdx") rdx,
            options(pure, nomem, nostack)
        );
        (rax, rdx)
    }
    // Portable fallback using 128-bit arithmetic
    #[cfg(any(not(target_arch = "x86_64"), miri))]
    {
        let dividend = (u128::from(hi) << 64) | u128::from(lo);
        let wide_divisor = u128::from(divisor);
        (
            (dividend / wide_divisor) as u64,
            (dividend % wide_divisor) as u64,
        )
    }
}
212213/// Perform `a += b`
214fn add_assign(a: &mut [u64], b: &[u64]) -> bool {
215 binop_slice(a, b, u64::overflowing_add)
216}
217218/// Perform `a -= b`
219fn sub_assign(a: &mut [u64], b: &[u64]) -> bool {
220 binop_slice(a, b, u64::overflowing_sub)
221}
/// Converts an overflowing binary operation on scalars to one on slices
///
/// Applies `binop` digit by digit to `a` and `b`, storing the result in `a` and
/// feeding the carry (or borrow) of each digit into the next. Only as many
/// digits as the shorter slice holds are processed. Returns the carry out of
/// the last processed digit.
fn binop_slice(a: &mut [u64], b: &[u64], binop: impl Fn(u64, u64) -> (u64, bool) + Copy) -> bool {
    let mut carry = false;
    for (lhs, &rhs) in a.iter_mut().zip(b) {
        // Fold the incoming carry into the right-hand digit first; this overflows
        // only when `rhs` is u64::MAX and there is a carry
        let (rhs_with_carry, carry_overflowed) = rhs.overflowing_add(u64::from(carry));
        let (digit, op_overflowed) = binop(*lhs, rhs_with_carry);
        *lhs = digit;
        carry = carry_overflowed || op_overflowed;
    }
    carry
}
234235/// Widening multiplication of an N-digit array with a u64
236fn full_mul_u64<const N: usize>(a: &[u64; N], b: u64) -> ArrayPlusOne<u64, N> {
237let mut carry = 0;
238let mut out = [0; N];
239 out.iter_mut().zip(a).for_each(|(o, v)| {
240let r = *v as u128 * b as u128 + carry as u128;
241*o = r as u64;
242 carry = (r >> 64) as u64;
243 });
244 ArrayPlusOne(out, carry)
245}
246247/// Left shift of an N-digit array by at most 63 bits
248fn shl_word<const N: usize>(v: &[u64; N], shift: u32) -> [u64; N] {
249 full_shl(v, shift).0
250}
251252/// Widening left shift of an N-digit array by at most 63 bits
253fn full_shl<const N: usize>(v: &[u64; N], shift: u32) -> ArrayPlusOne<u64, N> {
254debug_assert!(shift < 64);
255if shift == 0 {
256return ArrayPlusOne(*v, 0);
257 }
258let mut out = [0u64; N];
259 out[0] = v[0] << shift;
260for i in 1..N {
261 out[i] = v[i - 1] >> (64 - shift) | v[i] << shift
262 }
263let carry = v[N - 1] >> (64 - shift);
264 ArrayPlusOne(out, carry)
265}
266267/// Narrowing right shift of an (N+1)-digit array by at most 63 bits
268fn full_shr<const N: usize>(a: &ArrayPlusOne<u64, N>, shift: u32) -> [u64; N] {
269debug_assert!(shift < 64);
270if shift == 0 {
271return a.0;
272 }
273let mut out = [0; N];
274for i in 0..N - 1 {
275 out[i] = a[i] >> shift | a[i + 1] << (64 - shift)
276 }
277 out[N - 1] = a[N - 1] >> shift;
278 out
279}
/// An array of N + 1 elements
///
/// This is a hack around lack of support for const arithmetic: the first `N`
/// elements live in the array field and the last one in the trailing field.
/// `#[repr(C)]` keeps the fields in declaration order so the whole value can
/// be viewed as a single slice of `N + 1` elements through `Deref`.
#[repr(C)]
struct ArrayPlusOne<T, const N: usize>([T; N], T);

impl<T, const N: usize> std::ops::Deref for ArrayPlusOne<T, N> {
    type Target = [T];

    /// View all `N + 1` elements as one contiguous slice
    #[inline]
    fn deref(&self) -> &Self::Target {
        let first = (self as *const Self).cast::<T>();
        // SAFETY: with `#[repr(C)]` the `[T; N]` field is at offset 0 and is
        // immediately followed by the trailing `T` (same alignment, so no padding),
        // hence `first` points at `N + 1` initialized, contiguous `T`s that are
        // borrowed for the lifetime of `&self`
        unsafe { std::slice::from_raw_parts(first, N + 1) }
    }
}

impl<T, const N: usize> std::ops::DerefMut for ArrayPlusOne<T, N> {
    /// View all `N + 1` elements as one contiguous mutable slice
    fn deref_mut(&mut self) -> &mut Self::Target {
        let first = (self as *mut Self).cast::<T>();
        // SAFETY: same layout argument as in `deref`; `&mut self` guarantees
        // exclusive access to all `N + 1` elements for the returned lifetime
        unsafe { std::slice::from_raw_parts_mut(first, N + 1) }
    }
}