// rand/distributions/utils.rs
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions
11#[cfg(feature = "simd_support")] use packed_simd::*;
12
13
/// Widening multiply: computes the full double-width product of two
/// equal-width unsigned integers, returned as a `(high, low)` pair of
/// the original width (so no precision is lost to overflow).
pub(crate) trait WideningMultiply<RHS = Self> {
    /// The `(high, low)` halves of the full product.
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}
19
/// Implements `WideningMultiply` by casting both operands to a type of
/// twice the width, multiplying there, and splitting the result into
/// `(high, low)` halves.
macro_rules! wmul_impl {
    // Scalar implementation: `$ty` widened to `$wide`, split at `$shift` bits.
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                // The full product always fits in `$wide`; the top half is
                // the overflow, the bottom half the wrapped product.
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    // (note the extra comma after the pair list in invocations, e.g.
    // `wmul_impl! { (u16x2, u32x2),, 16 }` — it terminates the `$(...)+`).
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
// Scalar impls: each unsigned type widens to the next size up.
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
// Emscripten lacks efficient u128 support, so u64 instead uses the
// long-multiplication fallback defined below.
#[cfg(not(target_os = "emscripten"))]
wmul_impl! { u64, u128, 64 }
61
// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    // Scalar implementation: schoolbook multiplication on `$half`-bit digits,
    // for types that have no double-width counterpart to widen into.
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                // Mask selecting the low `$half` bits of a `$ty`.
                const LOWER_MASK: $ty = !0 >> $half;
                // lo*lo partial product; the overflow is carried via `t`.
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                // hi(self)*lo(b) partial product.
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                // hi(b)*lo(self) partial product.
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                // hi*hi partial product lands entirely in the high word.
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
// Emscripten has no u128, so u64 uses the long-multiplication path there.
#[cfg(target_os = "emscripten")]
wmul_impl_large! { u64, 32 }
// u128 has no wider type available, so it always long-multiplies.
#[cfg(not(target_os = "emscripten"))]
wmul_impl_large! { u128, 64 }
127
/// Implements `WideningMultiply for usize` by forwarding to the
/// fixed-width unsigned type of the same size as a pointer.
macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                // Delegate to the matching fixed-width impl, then cast back.
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }
145
/// SIMD `WideningMultiply` implementations (packed_simd vector types).
#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")] use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;

    // u8 vectors: widen each lane to u16. The double comma separates the
    // type-pair list from the shift amount in `wmul_impl`'s bulk arm.
    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    // Skipped when the `mulhi`-based implementations below are available.
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let b = $intrinsic::from_bits(x);
                    let a = $intrinsic::from_bits(self);
                    // SAFETY(review): the intrinsics require only the CPU
                    // feature the invocation site gates on via `cfg` — confirm.
                    // Using the signed `mullo` is fine: the low 16 bits of a
                    // product do not depend on operand signedness.
                    let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                    let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // Vectors with no double-width counterpart fall back to long multiply.
    // TODO: optimize, this seems to seriously slow things down
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
208
/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    // Per-lane boolean mask type: `bool` for scalars, a vector mask for SIMD.
    type Mask;
    fn finite_mask(self) -> Self::Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    // `UInt` is the unsigned integer type of matching width/lane count.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}
235
/// Implement functions available in std builds but missing from core primitives
// NOTE(review): `cfg(not(std))` tests for a raw `--cfg std` compiler flag, NOT
// the crate's "std" cargo feature (`cfg(feature = "std")`). As written the
// trait is likely always compiled in (harmless, since inherent std methods
// shadow trait methods) — confirm whether this is intended.
#[cfg(not(std))]
// False positive: We are following `std` here.
#[allow(clippy::wrong_self_convention)]
pub(crate) trait Float: Sized {
    fn is_nan(self) -> bool;
    fn is_infinite(self) -> bool;
    fn is_finite(self) -> bool;
}
245
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
///
/// A scalar float is treated as a vector with exactly one lane, so every
/// operation below is trivial: broadcasting is the identity and the only
/// valid lane index is 0.
pub(crate) trait FloatAsSIMD: Sized {
    /// Number of lanes — always 1 for a scalar.
    #[inline(always)]
    fn lanes() -> usize {
        1
    }

    /// "Broadcast" a scalar across the single lane (identity).
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }

    /// Read lane `index`; only lane 0 exists.
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }

    /// Overwrite lane `index`; since only lane 0 exists, the result is
    /// simply the replacement value.
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }
}
267
/// Gives `bool` an API similar to SIMD mask types (`any`/`all`/`none`
/// reductions over lanes), so scalar and vector code can be shared.
pub(crate) trait BoolAsSIMD: Sized {
    fn any(self) -> bool;
    fn all(self) -> bool;
    fn none(self) -> bool;
}
273
274impl BoolAsSIMD for bool {
275    #[inline(always)]
276    fn any(self) -> bool {
277        self
278    }
279
280    #[inline(always)]
281    fn all(self) -> bool {
282        self
283    }
284
285    #[inline(always)]
286    fn none(self) -> bool {
287        !self
288    }
289}
290
/// Implements `Float`, `FloatSIMDUtils` and `FloatAsSIMD` for the scalar
/// float type `$ty` whose same-width unsigned bit type is `$uty`.
macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        // NOTE(review): checks the `--cfg std` flag, not the "std" cargo
        // feature; see the `Float` trait declaration — confirm intent.
        #[cfg(not(std))]
        impl Float for $ty {
            #[inline]
            fn is_nan(self) -> bool {
                // IEEE 754: NaN is the only value that compares unequal
                // to itself.
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }
        }

        impl FloatSIMDUtils for $ty {
            // A scalar is a one-lane "vector": mask is a plain bool, and
            // the all_* reductions are ordinary comparisons.
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self >= other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                // Stepping the bit pattern down by one yields the next lower
                // representable float. NOTE(review): this assumes `self` is
                // positive and non-zero (for zero or negative values the bit
                // trick moves the wrong way) — confirm callers guarantee it.
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Numerical conversion (e.g. 3u32 -> 3.0), not a bit reinterpret.
                i as $ty
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}
360
// Instantiate the scalar float helpers for both primitive float types.
scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
363
364
/// Implements `FloatSIMDUtils` for the float vector `$ty` with scalar
/// element `$f_scalar`, mask vector `$mty` and same-width unsigned
/// vector `$uty`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                // Lane-wise compare, then reduce over all lanes.
                self.lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.finite_mask().all()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                // A lane is finite iff it lies strictly between the two
                // infinities; NaN fails both comparisons, so it is excluded.
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.gt(other)
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self.ge(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Lane-wise numerical conversion, not a bit reinterpret.
                i.cast()
            }
        }
    };
}
424
// Instantiate the SIMD float helpers for every supported vector width.
#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }