cipher/
stream_core.rs

1use crate::{ParBlocks, ParBlocksSizeUser, StreamCipherError};
2use crypto_common::{
3    generic_array::{ArrayLength, GenericArray},
4    typenum::Unsigned,
5    Block, BlockSizeUser,
6};
7use inout::{InOut, InOutBuf};
8
9/// Trait implemented by stream cipher backends.
10pub trait StreamBackend: ParBlocksSizeUser {
11    /// Generate keystream block.
12    fn gen_ks_block(&mut self, block: &mut Block<Self>);
13
14    /// Generate keystream blocks in parallel.
15    #[inline(always)]
16    fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks<Self>) {
17        for block in blocks {
18            self.gen_ks_block(block);
19        }
20    }
21
22    /// Generate keystream blocks. Length of the buffer MUST be smaller
23    /// than `Self::ParBlocksSize`.
24    #[inline(always)]
25    fn gen_tail_blocks(&mut self, blocks: &mut [Block<Self>]) {
26        assert!(blocks.len() < Self::ParBlocksSize::USIZE);
27        for block in blocks {
28            self.gen_ks_block(block);
29        }
30    }
31}
32
33/// Trait for [`StreamBackend`] users.
34///
35/// This trait is used to define rank-2 closures.
36pub trait StreamClosure: BlockSizeUser {
37    /// Execute closure with the provided stream cipher backend.
38    fn call<B: StreamBackend<BlockSize = Self::BlockSize>>(self, backend: &mut B);
39}
40
41/// Block-level synchronous stream ciphers.
42pub trait StreamCipherCore: BlockSizeUser + Sized {
43    /// Return number of remaining blocks before cipher wraps around.
44    ///
45    /// Returns `None` if number of remaining blocks can not be computed
46    /// (e.g. in ciphers based on the sponge construction) or it's too big
47    /// to fit into `usize`.
48    fn remaining_blocks(&self) -> Option<usize>;
49
50    /// Process data using backend provided to the rank-2 closure.
51    fn process_with_backend(&mut self, f: impl StreamClosure<BlockSize = Self::BlockSize>);
52
53    /// Write keystream block.
54    ///
55    /// WARNING: this method does not check number of remaining blocks!
56    #[inline]
57    fn write_keystream_block(&mut self, block: &mut Block<Self>) {
58        self.process_with_backend(WriteBlockCtx { block });
59    }
60
61    /// Write keystream blocks.
62    ///
63    /// WARNING: this method does not check number of remaining blocks!
64    #[inline]
65    fn write_keystream_blocks(&mut self, blocks: &mut [Block<Self>]) {
66        self.process_with_backend(WriteBlocksCtx { blocks });
67    }
68
69    /// Apply keystream block.
70    ///
71    /// WARNING: this method does not check number of remaining blocks!
72    #[inline]
73    fn apply_keystream_block_inout(&mut self, block: InOut<'_, '_, Block<Self>>) {
74        self.process_with_backend(ApplyBlockCtx { block });
75    }
76
77    /// Apply keystream blocks.
78    ///
79    /// WARNING: this method does not check number of remaining blocks!
80    #[inline]
81    fn apply_keystream_blocks(&mut self, blocks: &mut [Block<Self>]) {
82        self.process_with_backend(ApplyBlocksCtx {
83            blocks: blocks.into(),
84        });
85    }
86
87    /// Apply keystream blocks.
88    ///
89    /// WARNING: this method does not check number of remaining blocks!
90    #[inline]
91    fn apply_keystream_blocks_inout(&mut self, blocks: InOutBuf<'_, '_, Block<Self>>) {
92        self.process_with_backend(ApplyBlocksCtx { blocks });
93    }
94
95    /// Try to apply keystream to data not divided into blocks.
96    ///
97    /// Consumes cipher since it may consume final keystream block only
98    /// partially.
99    ///
100    /// Returns an error if number of remaining blocks is not sufficient
101    /// for processing the input data.
102    #[inline]
103    fn try_apply_keystream_partial(
104        mut self,
105        mut buf: InOutBuf<'_, '_, u8>,
106    ) -> Result<(), StreamCipherError> {
107        if let Some(rem) = self.remaining_blocks() {
108            let blocks = if buf.len() % Self::BlockSize::USIZE == 0 {
109                buf.len() % Self::BlockSize::USIZE
110            } else {
111                buf.len() % Self::BlockSize::USIZE + 1
112            };
113            if blocks > rem {
114                return Err(StreamCipherError);
115            }
116        }
117
118        if buf.len() > Self::BlockSize::USIZE {
119            let (blocks, tail) = buf.into_chunks();
120            self.apply_keystream_blocks_inout(blocks);
121            buf = tail;
122        }
123        let n = buf.len();
124        if n == 0 {
125            return Ok(());
126        }
127        let mut block = Block::<Self>::default();
128        block[..n].copy_from_slice(buf.get_in());
129        let t = InOutBuf::from_mut(&mut block);
130        self.apply_keystream_blocks_inout(t);
131        buf.get_out().copy_from_slice(&block[..n]);
132        Ok(())
133    }
134
135    /// Try to apply keystream to data not divided into blocks.
136    ///
137    /// Consumes cipher since it may consume final keystream block only
138    /// partially.
139    ///
140    /// # Panics
141    /// If number of remaining blocks is not sufficient for processing the
142    /// input data.
143    #[inline]
144    fn apply_keystream_partial(self, buf: InOutBuf<'_, '_, u8>) {
145        self.try_apply_keystream_partial(buf).unwrap()
146    }
147}
148
149// note: unfortunately, currently we can not write blanket impls of
150// `BlockEncryptMut` and `BlockDecryptMut` for `T: StreamCipherCore`
151// since it requires mutually exclusive traits, see:
152// https://github.com/rust-lang/rfcs/issues/1053
153
/// Counter type usable with [`StreamCipherCore`].
///
/// In this module the trait is implemented for `u32`, `u64`, and `u128`.
/// It's not intended to be implemented in third-party crates, but doing so
/// is not forbidden.
///
/// The trait has no methods of its own; it only bundles the fallible
/// conversions to and from `i32`, `u32`, `u64`, `u128`, and `usize` that a
/// counter type must support.
pub trait Counter:
    TryFrom<i32>
    + TryFrom<u32>
    + TryFrom<u64>
    + TryFrom<u128>
    + TryFrom<usize>
    + TryInto<i32>
    + TryInto<u32>
    + TryInto<u64>
    + TryInto<u128>
    + TryInto<usize>
{
}
172
/// Block-level seeking trait for stream ciphers.
///
/// Extends [`StreamCipherCore`] with the ability to read and change the
/// cipher's position, expressed in whole blocks.
pub trait StreamCipherSeekCore: StreamCipherCore {
    /// Counter type used inside stream cipher.
    ///
    /// Block positions are expressed in this type.
    type Counter: Counter;

    /// Get current block position.
    fn get_block_pos(&self) -> Self::Counter;

    /// Set block position.
    // NOTE(review): behavior for out-of-range `pos` is up to the
    // implementor; nothing in this trait constrains it.
    fn set_block_pos(&mut self, pos: Self::Counter);
}
184
185macro_rules! impl_counter {
186    {$($t:ty )*} => {
187        $( impl Counter for $t { } )*
188    };
189}
190
191impl_counter! { u32 u64 u128 }
192
193/// Partition buffer into 2 parts: buffer of arrays and tail.
194///
195/// In case if `N` is less or equal to 1, buffer of arrays has length
196/// of zero and tail is equal to `self`.
197#[inline]
198fn into_chunks<T, N: ArrayLength<T>>(buf: &mut [T]) -> (&mut [GenericArray<T, N>], &mut [T]) {
199    use core::slice;
200    if N::USIZE <= 1 {
201        return (&mut [], buf);
202    }
203    let chunks_len = buf.len() / N::USIZE;
204    let tail_pos = N::USIZE * chunks_len;
205    let tail_len = buf.len() - tail_pos;
206    unsafe {
207        let ptr = buf.as_mut_ptr();
208        let chunks = slice::from_raw_parts_mut(ptr as *mut GenericArray<T, N>, chunks_len);
209        let tail = slice::from_raw_parts_mut(ptr.add(tail_pos), tail_len);
210        (chunks, tail)
211    }
212}
213
214struct WriteBlockCtx<'a, BS: ArrayLength<u8>> {
215    block: &'a mut Block<Self>,
216}
217impl<'a, BS: ArrayLength<u8>> BlockSizeUser for WriteBlockCtx<'a, BS> {
218    type BlockSize = BS;
219}
220impl<'a, BS: ArrayLength<u8>> StreamClosure for WriteBlockCtx<'a, BS> {
221    #[inline(always)]
222    fn call<B: StreamBackend<BlockSize = BS>>(self, backend: &mut B) {
223        backend.gen_ks_block(self.block);
224    }
225}
226
227struct WriteBlocksCtx<'a, BS: ArrayLength<u8>> {
228    blocks: &'a mut [Block<Self>],
229}
230impl<'a, BS: ArrayLength<u8>> BlockSizeUser for WriteBlocksCtx<'a, BS> {
231    type BlockSize = BS;
232}
233impl<'a, BS: ArrayLength<u8>> StreamClosure for WriteBlocksCtx<'a, BS> {
234    #[inline(always)]
235    fn call<B: StreamBackend<BlockSize = BS>>(self, backend: &mut B) {
236        if B::ParBlocksSize::USIZE > 1 {
237            let (chunks, tail) = into_chunks::<_, B::ParBlocksSize>(self.blocks);
238            for chunk in chunks {
239                backend.gen_par_ks_blocks(chunk);
240            }
241            backend.gen_tail_blocks(tail);
242        } else {
243            for block in self.blocks {
244                backend.gen_ks_block(block);
245            }
246        }
247    }
248}
249
250struct ApplyBlockCtx<'inp, 'out, BS: ArrayLength<u8>> {
251    block: InOut<'inp, 'out, Block<Self>>,
252}
253
254impl<'inp, 'out, BS: ArrayLength<u8>> BlockSizeUser for ApplyBlockCtx<'inp, 'out, BS> {
255    type BlockSize = BS;
256}
257
258impl<'inp, 'out, BS: ArrayLength<u8>> StreamClosure for ApplyBlockCtx<'inp, 'out, BS> {
259    #[inline(always)]
260    fn call<B: StreamBackend<BlockSize = BS>>(mut self, backend: &mut B) {
261        let mut t = Default::default();
262        backend.gen_ks_block(&mut t);
263        self.block.xor_in2out(&t);
264    }
265}
266
267struct ApplyBlocksCtx<'inp, 'out, BS: ArrayLength<u8>> {
268    blocks: InOutBuf<'inp, 'out, Block<Self>>,
269}
270
271impl<'inp, 'out, BS: ArrayLength<u8>> BlockSizeUser for ApplyBlocksCtx<'inp, 'out, BS> {
272    type BlockSize = BS;
273}
274
275impl<'inp, 'out, BS: ArrayLength<u8>> StreamClosure for ApplyBlocksCtx<'inp, 'out, BS> {
276    #[inline(always)]
277    #[allow(clippy::needless_range_loop)]
278    fn call<B: StreamBackend<BlockSize = BS>>(self, backend: &mut B) {
279        if B::ParBlocksSize::USIZE > 1 {
280            let (chunks, mut tail) = self.blocks.into_chunks::<B::ParBlocksSize>();
281            for mut chunk in chunks {
282                let mut tmp = Default::default();
283                backend.gen_par_ks_blocks(&mut tmp);
284                chunk.xor_in2out(&tmp);
285            }
286            let n = tail.len();
287            let mut buf = GenericArray::<_, B::ParBlocksSize>::default();
288            let ks = &mut buf[..n];
289            backend.gen_tail_blocks(ks);
290            for i in 0..n {
291                tail.get(i).xor_in2out(&ks[i]);
292            }
293        } else {
294            for mut block in self.blocks {
295                let mut t = Default::default();
296                backend.gen_ks_block(&mut t);
297                block.xor_in2out(&t);
298            }
299        }
300    }
301}