half/bfloat/
convert.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/// Converts an `f32` into the raw bit pattern of a bfloat16 value.
///
/// The result is built from the upper 16 bits of the single-precision encoding. The
/// discarded lower 16 bits are rounded to nearest, ties to even. A NaN input stays a NaN:
/// its upper payload bits are kept and the most significant mantissa bit is forced on so
/// the payload can never collapse to an infinity.
pub(crate) fn f32_to_bf16(value: f32) -> u16 {
    const ABS_MASK: u32 = 0x7FFF_FFFF;
    const INFINITY_BITS: u32 = 0x7F80_0000;
    const QUIET_BIT: u32 = 0x0040;
    // Highest bit that gets discarded by the shift
    const ROUND_BIT: u32 = 0x0000_8000;
    // Every bit below the round bit, plus the lowest bit that survives the shift
    const STICKY_OR_ODD: u32 = 3 * ROUND_BIT - 1;

    let bits = value.to_bits();
    let upper = bits >> 16;

    // Anything whose magnitude bits exceed infinity is a NaN
    if (bits & ABS_MASK) > INFINITY_BITS {
        return (upper | QUIET_BIT) as u16;
    }

    // Round up when past the halfway point, or exactly halfway with an odd result
    let round_up = (bits & ROUND_BIT) != 0 && (bits & STICKY_OR_ODD) != 0;
    upper as u16 + u16::from(round_up)
}

/// Converts an `f64` into the raw bit pattern of a bfloat16 value.
///
/// The value is rounded to nearest, ties to even. Magnitudes too large for bfloat16 become
/// a signed infinity, magnitudes too small become a signed zero or a bfloat16 subnormal.
/// A NaN input stays a NaN: the top mantissa bits are kept and the most significant
/// mantissa bit is forced on.
///
/// All result bits come from the upper 32 bits of the input, but the lower 32 bits still
/// take part in rounding as sticky bits, so a value that sits just above a rounding tie is
/// rounded up rather than being mistaken for an exact tie.
pub(crate) fn f64_to_bf16(value: f64) -> u16 {
    // Convert to raw bytes. The upper 32 bits hold the sign, the exponent and every
    // mantissa bit that can end up in the result.
    let val = value.to_bits();
    let x = (val >> 32) as u32;
    // The lower 32 bits never reach the result directly, but any set bit there means the
    // value is strictly greater in magnitude than what the upper word alone encodes.
    let low_sticky = (val as u32) != 0;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // The lower 32 bits matter too: a NaN may carry its payload only there.
        let nan_bit = if man == 0 && !low_sticky {
            0
        } else {
            0x0040u32
        };
        return ((sign >> 16) | 0x7F80u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling the bfloat16 version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for bfloat16 precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 127;

    // Check for exponent overflow, return signed infinity
    if half_exp >= 0xFF {
        return (half_sign | 0x7F80u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 7 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Round to nearest, ties to even; `3 * round_bit - 1` selects every bit below the
        // round bit plus the lowest kept bit, and the low word adds further sticky bits.
        let round_bit = 1u32 << (13 - half_exp);
        if (man & round_bit) != 0 && ((man & (3 * round_bit - 1)) != 0 || low_sticky) {
            half_man += 1;
        }
        // No exponent for subnormals; a carry out of the mantissa correctly yields the
        // smallest normal number.
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 7;
    let half_man = man >> 13;
    // Round to nearest, ties to even, with the low word acting as extra sticky bits
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && ((man & (3 * round_bit - 1)) != 0 || low_sticky) {
        // Round it; a mantissa carry propagates into the exponent, up to infinity
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

/// Expands a raw bfloat16 bit pattern into an `f32`.
///
/// The 16 bits become the upper half of the single-precision encoding and the lower half
/// is zero, so every non-NaN value converts exactly. For a NaN the payload is kept and the
/// most significant mantissa bit is additionally set.
pub(crate) fn bf16_to_f32(i: u16) -> f32 {
    // Magnitude bits above the infinity pattern identify a NaN
    let is_nan = (i & 0x7FFF) > 0x7F80;
    let quiet_bit: u32 = if is_nan { 0x0040 } else { 0 };
    f32::from_bits((u32::from(i) | quiet_bit) << 16)
}

/// Expands a raw bfloat16 bit pattern into an `f64`.
///
/// Zeros and infinities keep their sign, subnormals are renormalized into the
/// double-precision exponent range, and a NaN keeps its payload in the top mantissa bits
/// with the most significant mantissa bit additionally set.
pub(crate) fn bf16_to_f64(i: u16) -> f64 {
    const F64_FRACTION_MASK: u64 = 0xF_FFFF_FFFF_FFFF;
    // Difference between the double-precision and bfloat16 exponent biases
    const BIAS_SHIFT: u32 = 1023 - 127;

    // Split the encoding into its three fields
    let sign = u64::from(i & 0x8000) << 48;
    let biased_exp = (i >> 7) & 0xFF;
    let fraction = u64::from(i & 0x007F);

    let magnitude: u64 = match (biased_exp, fraction) {
        // Signed zero
        (0, 0) => 0,
        // Signed infinity
        (0xFF, 0) => 0x7FF0_0000_0000_0000,
        // NaN: keep the payload and set the most significant mantissa bit
        (0xFF, _) => 0x7FF8_0000_0000_0000 | (fraction << 45),
        // Subnormal: shift the leading one out into the implicit bit and lower the
        // exponent by the same amount
        (0, _) => {
            let adjust = (i & 0x007F).leading_zeros() - 9;
            let exp = u64::from(BIAS_SHIFT - adjust) << 52;
            exp | ((fraction << (46 + adjust)) & F64_FRACTION_MASK)
        }
        // Normal number: rebias the exponent and widen the mantissa
        _ => {
            let exp = u64::from(u32::from(biased_exp) + BIAS_SHIFT) << 52;
            exp | (fraction << 45)
        }
    };

    f64::from_bits(sign | magnitude)
}