libm/math/
sqrt.rs

1/* origin: FreeBSD /usr/src/lib/msun/src/e_sqrt.c */
2/*
3 * ====================================================
4 * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
5 *
6 * Developed at SunSoft, a Sun Microsystems, Inc. business.
7 * Permission to use, copy, modify, and distribute this
8 * software is freely granted, provided that this notice
9 * is preserved.
10 * ====================================================
11 */
/* sqrt(x)
 * Return correctly rounded sqrt.
 *           ------------------------------------------
 *           |  Use the hardware sqrt if you have one |
 *           ------------------------------------------
 * Method:
 *   Bit by bit method using integer arithmetic. (Slow, but portable)
 *   1. Normalization
 *      Scale x to y in [1,4) with even powers of 2:
 *      find an integer k such that 1 \le y = x \cdot 2^{2k} < 4; then
 *              \sqrt{x} = 2^{-k} \sqrt{y}
 *   2. Bit by bit computation
 *      Let q_i = \sqrt{y} truncated to i bits after the binary point
 *      (so q_0 = 1), and define
 *
 *          s_i = 2 q_i,    y_i = 2^{i+1} (y - q_i^2).                  (1)
 *
 *      To compute q_{i+1} from q_i, one checks whether
 *
 *          (q_i + 2^{-(i+1)})^2 \le y.                                  (2)
 *
 *      If (2) is false, then q_{i+1} = q_i; otherwise
 *      q_{i+1} = q_i + 2^{-(i+1)}.
 *
 *      With some algebraic manipulation (expand the square in (2),
 *      subtract q_i^2 and multiply both sides by 2^{i+1}), it is not
 *      difficult to see that (2) is equivalent to
 *
 *          s_i + 2^{-(i+1)} \le y_i.                                    (3)
 *
 *      The advantage of (3) is that s_i and y_i can be computed by
 *      the following recurrence formulas. If (3) is false,
 *
 *          s_{i+1} = s_i,    y_{i+1} = 2 y_i;                           (4)
 *
 *      otherwise,
 *
 *          s_{i+1} = s_i + 2^{-i},    y_{i+1} = 2 (y_i - s_i - 2^{-(i+1)}).   (5)
 *
 *      The factor of 2 on y_{i+1} follows from definition (1); in the
 *      code it is the one-bit left shift of the remainder [ix0, ix1]
 *      performed on every iteration.
 *      One may easily use induction to prove (4) and (5).
 *      Note: since the left-hand side of (3) contains only i+2 bits,
 *      it is not necessary to do a full (53-bit) comparison in (3).
 *   3. Final rounding
 *      After generating the 53-bit result, we compute one more bit.
 *      Together with the remainder, we can decide whether the
 *      result is exact, bigger than 1/2 ulp, or less than 1/2 ulp
 *      (it will never be equal to 1/2 ulp).
 *      The rounding mode can be detected by checking whether
 *      huge + tiny is equal to huge, and whether huge - tiny is
 *      equal to huge, for some floating-point numbers "huge" and "tiny".
 *
 * Special cases:
 *      sqrt(+-0) = +-0         ... exact
 *      sqrt(inf) = inf
 *      sqrt(-ve) = NaN         ... with invalid signal
 *      sqrt(NaN) = NaN         ... with invalid signal for signaling NaN
 */
78
79use core::f64;
80
81#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
82pub fn sqrt(x: f64) -> f64 {
83    // On wasm32 we know that LLVM's intrinsic will compile to an optimized
84    // `f64.sqrt` native instruction, so we can leverage this for both code size
85    // and speed.
86    llvm_intrinsically_optimized! {
87        #[cfg(target_arch = "wasm32")] {
88            return if x < 0.0 {
89                f64::NAN
90            } else {
91                unsafe { ::core::intrinsics::sqrtf64(x) }
92            }
93        }
94    }
95    #[cfg(target_feature = "sse2")]
96    {
97        // Note: This path is unlikely since LLVM will usually have already
98        // optimized sqrt calls into hardware instructions if sse2 is available,
99        // but if someone does end up here they'll apprected the speed increase.
100        #[cfg(target_arch = "x86")]
101        use core::arch::x86::*;
102        #[cfg(target_arch = "x86_64")]
103        use core::arch::x86_64::*;
104        unsafe {
105            let m = _mm_set_sd(x);
106            let m_sqrt = _mm_sqrt_pd(m);
107            _mm_cvtsd_f64(m_sqrt)
108        }
109    }
110    #[cfg(not(target_feature = "sse2"))]
111    {
112        use core::num::Wrapping;
113
114        const TINY: f64 = 1.0e-300;
115
116        let mut z: f64;
117        let sign: Wrapping<u32> = Wrapping(0x80000000);
118        let mut ix0: i32;
119        let mut s0: i32;
120        let mut q: i32;
121        let mut m: i32;
122        let mut t: i32;
123        let mut i: i32;
124        let mut r: Wrapping<u32>;
125        let mut t1: Wrapping<u32>;
126        let mut s1: Wrapping<u32>;
127        let mut ix1: Wrapping<u32>;
128        let mut q1: Wrapping<u32>;
129
130        ix0 = (x.to_bits() >> 32) as i32;
131        ix1 = Wrapping(x.to_bits() as u32);
132
133        /* take care of Inf and NaN */
134        if (ix0 & 0x7ff00000) == 0x7ff00000 {
135            return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
136        }
137        /* take care of zero */
138        if ix0 <= 0 {
139            if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
140                return x; /* sqrt(+-0) = +-0 */
141            }
142            if ix0 < 0 {
143                return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
144            }
145        }
146        /* normalize x */
147        m = ix0 >> 20;
148        if m == 0 {
149            /* subnormal x */
150            while ix0 == 0 {
151                m -= 21;
152                ix0 |= (ix1 >> 11).0 as i32;
153                ix1 <<= 21;
154            }
155            i = 0;
156            while (ix0 & 0x00100000) == 0 {
157                i += 1;
158                ix0 <<= 1;
159            }
160            m -= i - 1;
161            ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
162            ix1 = ix1 << i as usize;
163        }
164        m -= 1023; /* unbias exponent */
165        ix0 = (ix0 & 0x000fffff) | 0x00100000;
166        if (m & 1) == 1 {
167            /* odd m, double x to make it even */
168            ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
169            ix1 += ix1;
170        }
171        m >>= 1; /* m = [m/2] */
172
173        /* generate sqrt(x) bit by bit */
174        ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
175        ix1 += ix1;
176        q = 0; /* [q,q1] = sqrt(x) */
177        q1 = Wrapping(0);
178        s0 = 0;
179        s1 = Wrapping(0);
180        r = Wrapping(0x00200000); /* r = moving bit from right to left */
181
182        while r != Wrapping(0) {
183            t = s0 + r.0 as i32;
184            if t <= ix0 {
185                s0 = t + r.0 as i32;
186                ix0 -= t;
187                q += r.0 as i32;
188            }
189            ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
190            ix1 += ix1;
191            r >>= 1;
192        }
193
194        r = sign;
195        while r != Wrapping(0) {
196            t1 = s1 + r;
197            t = s0;
198            if t < ix0 || (t == ix0 && t1 <= ix1) {
199                s1 = t1 + r;
200                if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
201                    s0 += 1;
202                }
203                ix0 -= t;
204                if ix1 < t1 {
205                    ix0 -= 1;
206                }
207                ix1 -= t1;
208                q1 += r;
209            }
210            ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
211            ix1 += ix1;
212            r >>= 1;
213        }
214
215        /* use floating add to find out rounding direction */
216        if (ix0 as u32 | ix1.0) != 0 {
217            z = 1.0 - TINY; /* raise inexact flag */
218            if z >= 1.0 {
219                z = 1.0 + TINY;
220                if q1.0 == 0xffffffff {
221                    q1 = Wrapping(0);
222                    q += 1;
223                } else if z > 1.0 {
224                    if q1.0 == 0xfffffffe {
225                        q += 1;
226                    }
227                    q1 += Wrapping(2);
228                } else {
229                    q1 += q1 & Wrapping(1);
230                }
231            }
232        }
233        ix0 = (q >> 1) + 0x3fe00000;
234        ix1 = q1 >> 1;
235        if (q & 1) == 1 {
236            ix1 |= sign;
237        }
238        ix0 += m << 20;
239        f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sanity_check() {
        assert_eq!(sqrt(100.0), 10.0);
        assert_eq!(sqrt(4.0), 2.0);
    }

    /// The spec: https://en.cppreference.com/w/cpp/numeric/math/sqrt
    #[test]
    fn spec_tests() {
        // Not Asserted: FE_INVALID exception is raised if argument is negative.
        assert!(sqrt(-1.0).is_nan());
        assert!(sqrt(f64::NEG_INFINITY).is_nan());
        assert!(sqrt(f64::NAN).is_nan());
        // Compare bit patterns rather than values: `0.0 == -0.0`, so a value
        // comparison could not detect a lost sign on `sqrt(-0.0)`.
        for f in [0.0, -0.0, f64::INFINITY].iter().copied() {
            assert_eq!(sqrt(f).to_bits(), f.to_bits(), "sqrt({:?})", f);
        }
    }

    /// A correctly rounded sqrt must be exact on perfect squares that are
    /// exactly representable.
    #[test]
    fn perfect_squares_are_exact() {
        for k in 0..1000u32 {
            let k = f64::from(k);
            assert_eq!(sqrt(k * k), k, "sqrt({}^2)", k);
        }
    }

    // The first input is deliberately a truncated PI literal: the expected
    // bit pattern belongs to this exact value, not to `f64::consts::PI`.
    #[allow(clippy::approx_constant)]
    #[test]
    fn conformance_tests() {
        let cases: [(f64, u64); 4] = [
            (3.14159265359, 4610661241675116657),
            (10000.0, 4636737291354636288),
            (f64::from_bits(0x0000000f), 2197470602079456986),
            (f64::INFINITY, 9218868437227405312),
        ];

        for &(input, expected) in cases.iter() {
            let bits = sqrt(input).to_bits();
            assert_eq!(expected, bits, "sqrt({:e})", input);
        }
    }
}