num_integer/
roots.rs

1use core;
2use core::mem;
3use traits::checked_pow;
4use traits::PrimInt;
5use Integer;
6
7/// Provides methods to compute an integer's square root, cube root,
8/// and arbitrary `n`th root.
9pub trait Roots: Integer {
10    /// Returns the truncated principal `n`th root of an integer
11    /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }`
12    ///
13    /// This is solving for `r` in `rⁿ = x`, rounding toward zero.
14    /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`.
15    /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`.
16    ///
17    /// # Panics
18    ///
19    /// Panics if `n` is zero:
20    ///
21    /// ```should_panic
22    /// # use num_integer::Roots;
23    /// println!("can't compute ⁰√x : {}", 123.nth_root(0));
24    /// ```
25    ///
26    /// or if `n` is even and `self` is negative:
27    ///
28    /// ```should_panic
29    /// # use num_integer::Roots;
30    /// println!("no imaginary numbers... {}", (-1).nth_root(10));
31    /// ```
32    ///
33    /// # Examples
34    ///
35    /// ```
36    /// use num_integer::Roots;
37    ///
38    /// let x: i32 = 12345;
39    /// assert_eq!(x.nth_root(1), x);
40    /// assert_eq!(x.nth_root(2), x.sqrt());
41    /// assert_eq!(x.nth_root(3), x.cbrt());
42    /// assert_eq!(x.nth_root(4), 10);
43    /// assert_eq!(x.nth_root(13), 2);
44    /// assert_eq!(x.nth_root(14), 1);
45    /// assert_eq!(x.nth_root(std::u32::MAX), 1);
46    ///
47    /// assert_eq!(std::i32::MAX.nth_root(30), 2);
48    /// assert_eq!(std::i32::MAX.nth_root(31), 1);
49    /// assert_eq!(std::i32::MIN.nth_root(31), -2);
50    /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1);
51    ///
52    /// assert_eq!(std::u32::MAX.nth_root(31), 2);
53    /// assert_eq!(std::u32::MAX.nth_root(32), 1);
54    /// ```
55    fn nth_root(&self, n: u32) -> Self;
56
57    /// Returns the truncated principal square root of an integer -- `⌊√x⌋`
58    ///
59    /// This is solving for `r` in `r² = x`, rounding toward zero.
60    /// The result will satisfy `r² ≤ x < (r+1)²`.
61    ///
62    /// # Panics
63    ///
64    /// Panics if `self` is less than zero:
65    ///
66    /// ```should_panic
67    /// # use num_integer::Roots;
68    /// println!("no imaginary numbers... {}", (-1).sqrt());
69    /// ```
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use num_integer::Roots;
75    ///
76    /// let x: i32 = 12345;
77    /// assert_eq!((x * x).sqrt(), x);
78    /// assert_eq!((x * x + 1).sqrt(), x);
79    /// assert_eq!((x * x - 1).sqrt(), x - 1);
80    /// ```
81    #[inline]
82    fn sqrt(&self) -> Self {
83        self.nth_root(2)
84    }
85
86    /// Returns the truncated principal cube root of an integer --
87    /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }`
88    ///
89    /// This is solving for `r` in `r³ = x`, rounding toward zero.
90    /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`.
91    /// If `x` is negative, then `(r-1)³ < x ≤ r³`.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use num_integer::Roots;
97    ///
98    /// let x: i32 = 1234;
99    /// assert_eq!((x * x * x).cbrt(), x);
100    /// assert_eq!((x * x * x + 1).cbrt(), x);
101    /// assert_eq!((x * x * x - 1).cbrt(), x - 1);
102    ///
103    /// assert_eq!((-(x * x * x)).cbrt(), -x);
104    /// assert_eq!((-(x * x * x + 1)).cbrt(), -x);
105    /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1));
106    /// ```
107    #[inline]
108    fn cbrt(&self) -> Self {
109        self.nth_root(3)
110    }
111}
112
113/// Returns the truncated principal square root of an integer --
114/// see [Roots::sqrt](trait.Roots.html#method.sqrt).
115#[inline]
116pub fn sqrt<T: Roots>(x: T) -> T {
117    x.sqrt()
118}
119
120/// Returns the truncated principal cube root of an integer --
121/// see [Roots::cbrt](trait.Roots.html#method.cbrt).
122#[inline]
123pub fn cbrt<T: Roots>(x: T) -> T {
124    x.cbrt()
125}
126
127/// Returns the truncated principal `n`th root of an integer --
128/// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root).
129#[inline]
130pub fn nth_root<T: Roots>(x: T, n: u32) -> T {
131    x.nth_root(n)
132}
133
// Implements `Roots` for the signed primitive `$T` by delegating to the
// unsigned type `$U` of the same width. A non-negative value is cast to `$U`
// and its root is computed there. For a negative value, the root of its
// unsigned magnitude is computed and then negated.
macro_rules! signed_roots {
    ($T:ty, $U:ty) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                if *self >= 0 {
                    (*self as $U).nth_root(n) as Self
                } else {
                    assert!(n.is_odd(), "even roots of a negative are imaginary");
                    // `wrapping_neg` turns `MIN` into its magnitude `2^(bits-1)`, which is
                    // exact once reinterpreted as `$U`. For `n == 1` that magnitude is also
                    // the root, and it does not fit in `$T`'s positive range: the cast
                    // yields `MIN` again. A plain `-` would then overflow (panicking in
                    // debug builds). Wrapping negation maps it back to `MIN`, the correct
                    // answer, and equals `-r` for every other root.
                    ((self.wrapping_neg() as $U).nth_root(n) as Self).wrapping_neg()
                }
            }

            #[inline]
            fn sqrt(&self) -> Self {
                assert!(*self >= 0, "the square root of a negative is imaginary");
                (*self as $U).sqrt() as Self
            }

            #[inline]
            fn cbrt(&self) -> Self {
                if *self >= 0 {
                    (*self as $U).cbrt() as Self
                } else {
                    // The cube root of any magnitude up to `2^(bits-1)` is far below
                    // `$T::MAX`, so this negation cannot overflow.
                    -((self.wrapping_neg() as $U).cbrt() as Self)
                }
            }
        }
    };
}
164
165signed_roots!(i8, u8);
166signed_roots!(i16, u16);
167signed_roots!(i32, u32);
168signed_roots!(i64, u64);
169#[cfg(has_i128)]
170signed_roots!(i128, u128);
171signed_roots!(isize, usize);
172
173#[inline]
174fn fixpoint<T, F>(mut x: T, f: F) -> T
175where
176    T: Integer + Copy,
177    F: Fn(T) -> T,
178{
179    let mut xn = f(x);
180    while x < xn {
181        x = xn;
182        xn = f(x);
183    }
184    while x > xn {
185        x = xn;
186        xn = f(x);
187    }
188    x
189}
190
/// Number of bits occupied by a value of type `T` (eight per byte of `size_of::<T>()`).
#[inline]
fn bits<T>() -> u32 {
    let bytes = mem::size_of::<T>() as u32;
    bytes * 8
}
195
196#[inline]
197fn log2<T: PrimInt>(x: T) -> u32 {
198    debug_assert!(x > T::zero());
199    bits::<T>() - 1 - x.leading_zeros()
200}
201
// Implements `Roots` for the unsigned primitive `$T`. Each method defines an
// inner `go` function over plain `$T` values and calls it with `*self`.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                fn go(a: $T, n: u32) -> $T {
                    // Specialize small roots: degree 0 is rejected, degree 1 is the
                    // identity, and degrees 2 and 3 use the dedicated `sqrt`/`cbrt`.
                    match n {
                        0 => panic!("can't find a root of degree 0!"),
                        1 => return a,
                        2 => return a.sqrt(),
                        3 => return a.cbrt(),
                        _ => (),
                    }

                    // The root of values less than 2ⁿ can only be 0 or 1.
                    // `bits <= n` is tested first so that `1 << n` is never evaluated
                    // with a shift amount at or beyond the width of `$T`.
                    if bits::<$T>() <= n || a < (1 << n) {
                        return (a > 0) as $T;
                    }

                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).nth_root(n) as $T
                        } else {
                            // With r = ⌊ⁿ√a⌋, the root of `a >> n` (= ⌊a / 2ⁿ⌋) is ⌊r / 2⌋,
                            // so r is either `lo = 2·⌊r/2⌋` or `lo + 1`; test `hi` directly.
                            let lo = (a >> n).nth_root(n) << 1;
                            let hi = lo + 1;
                            // 128-bit `checked_mul` also involves division, but we can't always
                            // compute `hiⁿ` without risking overflow.  Try to avoid it though...
                            // `hi ≤ 2^k` with k = trailing_zeros(next_power_of_two(hi)), so when
                            // k·n < bits, `hi.pow(n)` cannot overflow and the cheap path is used.
                            if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
                                match checked_pow(hi, n as usize) {
                                    Some(x) if x <= a => hi,
                                    _ => lo,
                                }
                            } else {
                                if hi.pow(n) <= a {
                                    hi
                                } else {
                                    lo
                                }
                            }
                        };
                    }

                    // Initial estimate for the Newton iteration below: a power of two,
                    // 2^⌈⌊log₂ x⌋ / n⌉, or for larger inputs the `f64` value exp(ln x / n).
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        // for smaller inputs, `f64` doesn't justify its cost.
                        if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
                            1 << ((log2(x) + n - 1) / n)
                        } else {
                            ((x as f64).ln() / f64::from(n)).exp() as $T
                        }
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        1 << ((log2(x) + n - 1) / n)
                    }

                    // https://en.wikipedia.org/wiki/Nth_root_algorithm
                    // Integer Newton step: x' = ((n-1)·x + a / xⁿ⁻¹) / n.
                    // If xⁿ⁻¹ overflows `$T` it is larger than `a`, so the integer
                    // quotient `a / xⁿ⁻¹` would be 0; 0 is substituted directly.
                    let n1 = n - 1;
                    let next = |x: $T| {
                        let y = match checked_pow(x, n1 as usize) {
                            Some(ax) => a / ax,
                            None => 0,
                        };
                        (y + x * n1 as $T) / n as $T
                    };
                    // `fixpoint` applies `next` from the guess until the value stops moving.
                    fixpoint(guess(a, n), next)
                }
                go(*self, n)
            }

            #[inline]
            fn sqrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).sqrt() as $T
                        } else {
                            // ⌊√a⌋ is either 2·⌊√(a >> 2)⌋ or one more than that.
                            let lo = (a >> 2u32).sqrt() << 1;
                            let hi = lo + 1;
                            if hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    // The roots of 0..=3 are 0 (for 0) or 1.
                    if a < 4 {
                        return (a > 0) as $T;
                    }

                    // Initial estimate: the `f64` square root, or without `std`
                    // the power of two 2^⌈⌊log₂ x⌋ / 2⌉.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).sqrt() as $T
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 1) / 2)
                    }

                    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
                    // Integer Babylonian step: x' = (a / x + x) / 2.
                    let next = |x: $T| (a / x + x) >> 1;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }

            #[inline]
            fn cbrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).cbrt() as $T
                        } else {
                            // ⌊∛a⌋ is either 2·⌊∛(a >> 3)⌋ or one more than that.
                            let lo = (a >> 3u32).cbrt() << 1;
                            let hi = lo + 1;
                            if hi * hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    if bits::<$T>() <= 32 {
                        // Implementation based on Hacker's Delight `icbrt2`
                        // Digit-by-digit cube root: the shift `s` walks down through the
                        // multiples of 3 from `3 * (bits / 3)` to 0. Each step doubles the
                        // trial root `y` while keeping `y2 == y * y`. Since
                        // `b = 3*(y² + y) + 1 = (y+1)³ - y³`, the test asks whether the
                        // remainder `x` can absorb raising the trial root to `y + 1`
                        // at this position.
                        let mut x = a;
                        let mut y2 = 0;
                        let mut y = 0;
                        let smax = bits::<$T>() / 3;
                        for s in (0..smax + 1).rev() {
                            let s = s * 3;
                            y2 *= 4;
                            y *= 2;
                            let b = 3 * (y2 + y) + 1;
                            if x >> s >= b {
                                x -= b << s;
                                y2 += 2 * y + 1;
                                y += 1;
                            }
                        }
                        return y;
                    }

                    // The roots of 0..=7 are 0 (for 0) or 1.
                    if a < 8 {
                        return (a > 0) as $T;
                    }
                    // Values that fit in `u32` go through the digit-by-digit path above.
                    if a <= core::u32::MAX as $T {
                        return (a as u32).cbrt() as $T;
                    }

                    // Initial estimate: the `f64` cube root, or without `std`
                    // the power of two 2^⌈⌊log₂ x⌋ / 3⌉.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).cbrt() as $T
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 2) / 3)
                    }

                    // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
                    // Integer Newton step: x' = (a / x² + 2·x) / 3.
                    let next = |x: $T| (a / (x * x) + x * 2) / 3;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }
        }
    };
}
384
385unsigned_roots!(u8);
386unsigned_roots!(u16);
387unsigned_roots!(u32);
388unsigned_roots!(u64);
389#[cfg(has_i128)]
390unsigned_roots!(u128);
391unsigned_roots!(usize);