// num_traits/ops/mul_add.rs

/// Fused multiply-add. Computes `(self * a) + b` with only one rounding
/// error, yielding a more accurate result than an unfused multiply-add.
///
/// Using `mul_add` can be more performant than an unfused multiply-add if
/// the target architecture has a dedicated `fma` CPU instruction.
///
/// Note that `A` and `B` are `Self` by default, but this is not mandatory.
///
/// The implementations for the primitive integer types have nothing to
/// round, so they simply evaluate `(self * a) + b` with the ordinary `*` and
/// `+` operators. Overflow is therefore handled exactly as for those
/// operators: a panic when overflow checks are enabled, wrapping otherwise.
///
/// # Example
///
/// ```
/// use num_traits::MulAdd;
///
/// // The integer implementations are always available.
/// assert_eq!(MulAdd::mul_add(2_i32, 3_i32, 4_i32), 10);
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // `m * x + b` is exactly 100.0 here, so the fused and unfused results
/// // agree to within a small multiple of `f32::EPSILON`.
/// let abs_difference = (m.mul_add(x, b) - (m * x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// The resulting type after applying the fused multiply-add.
    type Output;

    /// Performs the fused multiply-add operation `(self * a) + b`.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
30
/// In-place fused multiply-add: replaces `*self` with `(*self * a) + b`.
///
/// This is the assignment counterpart of [`MulAdd`]. As with `MulAdd`, the
/// operand types `A` and `B` default to `Self`, but an implementation may
/// choose any other types.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Overwrites `*self` with the fused multiply-add `(*self * a) + b`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
36
// Only compiled when a floating-point backend (`std` or `libm`) is enabled;
// the computation itself is delegated to `crate::Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f32, f32> for f32 {
    type Output = Self;

    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self::Output {
        crate::Float::mul_add(self, a, b)
    }
}
46
// Only compiled when a floating-point backend (`std` or `libm`) is enabled;
// the computation itself is delegated to `crate::Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f64, f64> for f64 {
    type Output = Self;

    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self::Output {
        crate::Float::mul_add(self, a, b)
    }
}
56
57macro_rules! mul_add_impl {
58    ($trait_name:ident for $($t:ty)*) => {$(
59        impl $trait_name for $t {
60            type Output = Self;
61
62            #[inline]
63            fn mul_add(self, a: Self, b: Self) -> Self::Output {
64                (self * a) + b
65            }
66        }
67    )*}
68}
69
70mul_add_impl!(MulAdd for isize i8 i16 i32 i64 i128);
71mul_add_impl!(MulAdd for usize u8 u16 u32 u64 u128);
72
// Only compiled when a floating-point backend (`std` or `libm`) is enabled;
// the new value comes from `crate::Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f32, f32> for f32 {
    #[inline]
    fn mul_add_assign(&mut self, a: Self, b: Self) {
        let current = *self;
        *self = crate::Float::mul_add(current, a, b);
    }
}
80
// Only compiled when a floating-point backend (`std` or `libm`) is enabled;
// the new value comes from `crate::Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f64, f64> for f64 {
    #[inline]
    fn mul_add_assign(&mut self, a: Self, b: Self) {
        let current = *self;
        *self = crate::Float::mul_add(current, a, b);
    }
}
88
89macro_rules! mul_add_assign_impl {
90    ($trait_name:ident for $($t:ty)*) => {$(
91        impl $trait_name for $t {
92            #[inline]
93            fn mul_add_assign(&mut self, a: Self, b: Self) {
94                *self = (*self * a) + b
95            }
96        }
97    )*}
98}
99
100mul_add_assign_impl!(MulAddAssign for isize i8 i16 i32 i64 i128);
101mul_add_assign_impl!(MulAddAssign for usize u8 u16 u32 u64 u128);
102
#[cfg(test)]
mod tests {
    use super::*;

    /// `MulAdd` on every integer type this module implements it for,
    /// including the 128-bit types.
    #[test]
    fn mul_add_integer() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
                    }
                )+
            };
        }

        test_mul_add!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// `MulAddAssign` on every integer type this module implements it for.
    #[test]
    fn mul_add_assign_integer() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        let mut m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        MulAddAssign::mul_add_assign(&mut m, x, b);

                        assert_eq!(m, 10);
                    }
                )+
            };
        }

        test_mul_add_assign!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// The fused float result agrees with the unfused one to within a few
    /// epsilons of the exact value (12 * 3.4 + 5.6 = 46.4).
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();

                        // Associated constant; no legacy `core::$t` module import needed.
                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add!(f32 f64);
    }

    /// The in-place float variant obeys the same error bound as `mul_add`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_assign_float() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        let start: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let mut m = start;
                        MulAddAssign::mul_add_assign(&mut m, x, b);

                        let abs_difference = (m - (start*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add_assign!(f32 f64);
    }
}