// chacha20/backends/sse2.rs

1use crate::{Block, StreamClosure, Unsigned, STATE_WORDS};
2use cipher::{
3    consts::{U1, U64},
4    BlockSizeUser, ParBlocksSizeUser, StreamBackend,
5};
6use core::marker::PhantomData;
7
8#[cfg(target_arch = "x86")]
9use core::arch::x86::*;
10#[cfg(target_arch = "x86_64")]
11use core::arch::x86_64::*;
12
/// Runs the stream closure `f` against an SSE2 backend initialized from
/// `state`, then writes the (possibly advanced) 32-bit block counter back
/// into `state[12]`.
///
/// `R` is the number of double rounds applied per block (see [`rounds`]).
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 (always true on `x86_64`;
/// must be verified via feature detection on 32-bit `x86`).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Unsigned,
    F: StreamClosure<BlockSize = U64>,
{
    // Load the 16-word state as four 128-bit rows. `loadu` handles unaligned
    // pointers, so no alignment requirement is placed on `state`.
    let state_ptr = state.as_ptr() as *const __m128i;
    let mut backend = Backend::<R> {
        v: [
            _mm_loadu_si128(state_ptr.add(0)),
            _mm_loadu_si128(state_ptr.add(1)),
            _mm_loadu_si128(state_ptr.add(2)),
            _mm_loadu_si128(state_ptr.add(3)),
        ],
        _pd: PhantomData,
    };

    f.call(&mut backend);

    // `gen_ks_block` only ever mutates the block counter (word 12, lane 0 of
    // the fourth row), so that is the only word that needs copying back.
    state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
}
35
/// SSE2 ChaCha block backend, generic over the double-round count `R`.
struct Backend<R: Unsigned> {
    // The 4x4 u32 ChaCha state, held as four 128-bit rows (a, b, c, d).
    v: [__m128i; 4],
    // Carries the round-count type parameter without storing a value.
    _pd: PhantomData<R>,
}
40
impl<R: Unsigned> BlockSizeUser for Backend<R> {
    // A ChaCha keystream block is 64 bytes (16 x u32).
    type BlockSize = U64;
}
44
impl<R: Unsigned> ParBlocksSizeUser for Backend<R> {
    // This backend computes a single block per call — no multi-block
    // parallelism (contrast with wider AVX2/AVX-512 backends).
    type ParBlocksSize = U1;
}
48
49impl<R: Unsigned> StreamBackend for Backend<R> {
50    #[inline(always)]
51    fn gen_ks_block(&mut self, block: &mut Block) {
52        unsafe {
53            let res = rounds::<R>(&self.v);
54            self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1));
55
56            let block_ptr = block.as_mut_ptr() as *mut __m128i;
57            for i in 0..4 {
58                _mm_storeu_si128(block_ptr.add(i), res[i]);
59            }
60        }
61    }
62}
63
64#[inline]
65#[target_feature(enable = "sse2")]
66unsafe fn rounds<R: Unsigned>(v: &[__m128i; 4]) -> [__m128i; 4] {
67    let mut res = *v;
68    for _ in 0..R::USIZE {
69        double_quarter_round(&mut res);
70    }
71
72    for i in 0..4 {
73        res[i] = _mm_add_epi32(res[i], v[i]);
74    }
75
76    res
77}
78
/// Applies one ChaCha "double round": a column round followed by a diagonal
/// round.
///
/// The diagonal round is realized by shuffling the state so the diagonals
/// line up as columns ([`rows_to_cols`]), re-running the same four parallel
/// quarter-rounds ([`add_xor_rot`]), and shuffling back ([`cols_to_rows`]).
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn double_quarter_round(v: &mut [__m128i; 4]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}
87
/// The goal of this function is to transform the state words from:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b1, b2, b3, b0] == [ 5,  6,  7,  4]
/// [c2, c3, c0, c1]    [10, 11,  8,  9]
/// [d3, d0, d1, d2]    [15, 12, 13, 14]
/// ```
///
/// so that we can apply [`add_xor_rot`] to the resulting columns, and have it compute the
/// "diagonal rounds" (as defined in RFC 7539) in parallel. In practice, this shuffle is
/// non-optimal: the last state word to be altered in `add_xor_rot` is `b`, so the shuffle
/// blocks on the result of `b` being calculated.
///
/// We can optimize this by observing that the four quarter rounds in `add_xor_rot` are
/// data-independent: they only access a single column of the state, and thus the order of
/// the columns does not matter. We therefore instead shuffle the other three state words,
/// to obtain the following equivalent layout:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// See <https://github.com/sneves/blake2-avx2/pull/4> for additional details. The earliest
/// known occurrence of this optimization is in floodyberry's SSE4 ChaCha code from 2014:
/// - <https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643>
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rows_to_cols([a, _, c, d]: &mut [__m128i; 4]) {
    // b is deliberately untouched — see the doc comment above.
    // c >>>= 32; d >>>= 64; a >>>= 96;
    *c = _mm_shuffle_epi32(*c, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    *d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *a = _mm_shuffle_epi32(*a, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
}
131
/// The goal of this function is to transform the state words from:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn cols_to_rows([a, _, c, d]: &mut [__m128i; 4]) {
    // Each shuffle is the inverse lane rotation of the one applied in
    // `rows_to_cols`; b was never shuffled and needs no fixup.
    // c <<<= 32; d <<<= 64; a <<<= 96;
    *c = _mm_shuffle_epi32(*c, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    *d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *a = _mm_shuffle_epi32(*a, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
}
157
/// Runs four ChaCha quarter-rounds in parallel, one per column of the state.
///
/// SSE2 has no 32-bit rotate instruction, so each left-rotation by `n` is
/// emulated as `(x << n) ^ (x >> (32 - n))`. XOR and OR are interchangeable
/// here because the two shifted halves occupy disjoint bit positions.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn add_xor_rot([a, b, c, d]: &mut [__m128i; 4]) {
    // a += b; d ^= a; d <<<= (16, 16, 16, 16);
    *a = _mm_add_epi32(*a, *b);
    *d = _mm_xor_si128(*d, *a);
    *d = _mm_xor_si128(_mm_slli_epi32(*d, 16), _mm_srli_epi32(*d, 16));

    // c += d; b ^= c; b <<<= (12, 12, 12, 12);
    *c = _mm_add_epi32(*c, *d);
    *b = _mm_xor_si128(*b, *c);
    *b = _mm_xor_si128(_mm_slli_epi32(*b, 12), _mm_srli_epi32(*b, 20));

    // a += b; d ^= a; d <<<= (8, 8, 8, 8);
    *a = _mm_add_epi32(*a, *b);
    *d = _mm_xor_si128(*d, *a);
    *d = _mm_xor_si128(_mm_slli_epi32(*d, 8), _mm_srli_epi32(*d, 24));

    // c += d; b ^= c; b <<<= (7, 7, 7, 7);
    *c = _mm_add_epi32(*c, *d);
    *b = _mm_xor_si128(*b, *c);
    *b = _mm_xor_si128(_mm_slli_epi32(*b, 7), _mm_srli_epi32(*b, 25));
}