chacha20/backends/avx2.rs

use crate::{Block, StreamClosure, Unsigned, STATE_WORDS};
use cipher::{
    consts::{U4, U64},
    BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamBackend,
};
use core::marker::PhantomData;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Number of blocks processed in parallel.
const PAR_BLOCKS: usize = 4;
/// Number of `__m256i` to store parallel blocks.
const N: usize = PAR_BLOCKS / 2;

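/// Runs the stream closure `f` with an AVX2 backend that processes [`PAR_BLOCKS`] ChaCha
/// blocks in parallel. Each `__m256i` holds the same 16-byte state row for two consecutive
/// blocks (one block per 128-bit lane), giving `N` block pairs.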
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Unsigned,
    F: StreamClosure<BlockSize = U64>,
{
    let state_ptr = state.as_ptr() as *const __m128i;
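    // State rows 0..=2 (the constants and the key) are identical for every block, so
    // broadcast each 16-byte row into both 128-bit lanes.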
    let v = [
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))),
        _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
    ];
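    // Row 3 holds the block counter (word 12) and the nonce. Bump the counter in the
    // high lane by 1 so each vector covers two consecutive blocks, then space the `N`
    // counter vectors two blocks apart.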
    let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
    c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
    }
    let mut backend = Backend::<R> {
        v,
        ctr,
        _pd: PhantomData,
    };

    f.call(&mut backend);

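    // Write the low lane's counter back into the state so the caller observes how many
    // blocks were generated.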
    state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
}

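/// AVX2 backend state: `v` holds state rows 0..=2 broadcast into both lanes, while `ctr`
/// holds `N` copies of row 3 whose counters together cover `PAR_BLOCKS` consecutive blocks.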
struct Backend<R: Unsigned> {
    v: [__m256i; 3],
    ctr: [__m256i; N],
    _pd: PhantomData<R>,
}

impl<R: Unsigned> BlockSizeUser for Backend<R> {
    type BlockSize = U64;
}

impl<R: Unsigned> ParBlocksSizeUser for Backend<R> {
    type ParBlocksSize = U4;
}

impl<R: Unsigned> StreamBackend for Backend<R> {
    #[inline(always)]
    fn gen_ks_block(&mut self, block: &mut Block) {
        unsafe {
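            // Run the rounds for the full set of block pairs, but emit only the first
            // block; every counter pair advances by one block (in both lanes).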
            let res = rounds::<R>(&self.v, &self.ctr);
            for c in self.ctr.iter_mut() {
                *c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1));
            }

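            // `res[0]` interleaves two blocks: even `__m128i` halves come from the low
            // lanes, odd halves from the high lanes. Keep only the even (low-lane) halves
            // for this single block.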
            let res0: [__m128i; 8] = core::mem::transmute(res[0]);

            let block_ptr = block.as_mut_ptr() as *mut __m128i;
            for i in 0..4 {
                _mm_storeu_si128(block_ptr.add(i), res0[2 * i]);
            }
        }
    }

    #[inline(always)]
    fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks<Self>) {
        unsafe {
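            // Generate keystream for all `PAR_BLOCKS` blocks at once, then advance every
            // counter pair by `PAR_BLOCKS` in both lanes.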
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
            }

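            // De-interleave each block pair: the even `__m128i` halves form the low-lane
            // block and the odd halves the high-lane block, written out as two consecutive
            // 64-byte blocks.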
            let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
            for v in vs {
                let t: [__m128i; 8] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[2 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]);
                }
                block_ptr = block_ptr.add(8);
            }
        }
    }
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rounds<R: Unsigned>(v: &[__m256i; 3], c: &[__m256i; N]) -> [[__m256i; 4]; N] {
    let mut vs: [[__m256i; 4]; N] = [[_mm256_setzero_si256(); 4]; N];
    for i in 0..N {
        vs[i] = [v[0], v[1], v[2], c[i]];
    }
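    // `R` counts double rounds (e.g. a value of 10 yields ChaCha20's 20 rounds).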
    for _ in 0..R::USIZE {
        double_quarter_round(&mut vs);
    }

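    // Feed-forward: add the original input words back into the permuted state, as the
    // ChaCha block function requires.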
    for i in 0..N {
        for j in 0..3 {
            vs[i][j] = _mm256_add_epi32(vs[i][j], v[j]);
        }
        vs[i][3] = _mm256_add_epi32(vs[i][3], c[i]);
    }

    vs
}

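/// One ChaCha double round: a column round, a shuffle into diagonals, a diagonal round,
/// and a shuffle back into rows.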
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn double_quarter_round(v: &mut [[__m256i; 4]; N]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b1, b2, b3, b0] == [ 5,  6,  7,  4]
/// [c2, c3, c0, c1]    [10, 11,  8,  9]
/// [d3, d0, d1, d2]    [15, 12, 13, 14]
/// ```
///
/// so that we can apply [`add_xor_rot`] to the resulting columns, and have it compute the
/// "diagonal rounds" (as defined in RFC 7539) in parallel. In practice, this shuffle is
/// non-optimal: the last state word to be altered in `add_xor_rot` is `b`, so the shuffle
/// blocks on the result of `b` being calculated.
///
/// We can optimize this by observing that the four quarter rounds in `add_xor_rot` are
/// data-independent: they only access a single column of the state, and thus the order of
/// the columns does not matter. We therefore instead shuffle the other three state words,
/// to obtain the following equivalent layout:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// See <https://github.com/sneves/blake2-avx2/pull/4> for additional details. The earliest
/// known occurrence of this optimization is in floodyberry's SSE4 ChaCha code from 2014:
/// - <https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643>
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rows_to_cols(vs: &mut [[__m256i; 4]; N]) {
    // c >>>= 32; d >>>= 64; a >>>= 96;
    for [a, _, c, d] in vs {
        *c = _mm256_shuffle_epi32(*c, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
        *d = _mm256_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm256_shuffle_epi32(*a, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    }
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn cols_to_rows(vs: &mut [[__m256i; 4]; N]) {
    // c <<<= 32; d <<<= 64; a <<<= 96;
    for [a, _, c, d] in vs {
        *c = _mm256_shuffle_epi32(*c, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
        *d = _mm256_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm256_shuffle_epi32(*a, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    }
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_xor_rot(vs: &mut [[__m256i; 4]; N]) {
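    // Rotations by 16 and 8 bits are byte-aligned, so they are implemented with a per-lane
    // byte shuffle (`vpshufb`); the 12- and 7-bit rotations below use shift + OR instead.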
    let rol16_mask = _mm256_set_epi64x(
        0x0d0c_0f0e_0908_0b0a,
        0x0504_0706_0100_0302,
        0x0d0c_0f0e_0908_0b0a,
        0x0504_0706_0100_0302,
    );
    let rol8_mask = _mm256_set_epi64x(
        0x0e0d_0c0f_0a09_080b,
        0x0605_0407_0201_0003,
        0x0e0d_0c0f_0a09_080b,
        0x0605_0407_0201_0003,
    );

    // a += b; d ^= a; d <<<= (16, 16, 16, 16);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm256_add_epi32(*a, *b);
        *d = _mm256_xor_si256(*d, *a);
        *d = _mm256_shuffle_epi8(*d, rol16_mask);
    }

    // c += d; b ^= c; b <<<= (12, 12, 12, 12);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm256_add_epi32(*c, *d);
        *b = _mm256_xor_si256(*b, *c);
        *b = _mm256_xor_si256(_mm256_slli_epi32(*b, 12), _mm256_srli_epi32(*b, 20));
    }

    // a += b; d ^= a; d <<<= (8, 8, 8, 8);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm256_add_epi32(*a, *b);
        *d = _mm256_xor_si256(*d, *a);
        *d = _mm256_shuffle_epi8(*d, rol8_mask);
    }

    // c += d; b ^= c; b <<<= (7, 7, 7, 7);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm256_add_epi32(*c, *d);
        *b = _mm256_xor_si256(*b, *c);
        *b = _mm256_xor_si256(_mm256_slli_epi32(*b, 7), _mm256_srli_epi32(*b, 25));
    }
}