cipher/
stream_wrapper.rs

1use crate::{
2    errors::StreamCipherError, Block, OverflowError, SeekNum, StreamCipher, StreamCipherCore,
3    StreamCipherSeek, StreamCipherSeekCore,
4};
5use crypto_common::{
6    typenum::{IsLess, Le, NonZero, Unsigned, U256},
7    BlockSizeUser, Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser,
8};
9use inout::InOutBuf;
10#[cfg(feature = "zeroize")]
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13/// Wrapper around [`StreamCipherCore`] implementations.
14///
15/// It handles data buffering and implements the slice-based traits.
16#[derive(Clone, Default)]
17pub struct StreamCipherCoreWrapper<T: BlockSizeUser>
18where
19    T::BlockSize: IsLess<U256>,
20    Le<T::BlockSize, U256>: NonZero,
21{
22    core: T,
23    buffer: Block<T>,
24    pos: u8,
25}
26
27impl<T: StreamCipherCore> StreamCipherCoreWrapper<T>
28where
29    T::BlockSize: IsLess<U256>,
30    Le<T::BlockSize, U256>: NonZero,
31{
32    /// Return reference to the core type.
33    pub fn get_core(&self) -> &T {
34        &self.core
35    }
36
37    /// Return reference to the core type.
38    pub fn from_core(core: T) -> Self {
39        Self {
40            core,
41            buffer: Default::default(),
42            pos: 0,
43        }
44    }
45
46    /// Return current cursor position.
47    #[inline]
48    fn get_pos(&self) -> usize {
49        let pos = self.pos as usize;
50        if T::BlockSize::USIZE == 0 {
51            panic!("Block size can not be equal to zero");
52        }
53        if pos >= T::BlockSize::USIZE {
54            debug_assert!(false);
55            // SAFETY: `pos` is set only to values smaller than block size
56            unsafe { core::hint::unreachable_unchecked() }
57        }
58        self.pos as usize
59    }
60
61    /// Return size of the internal buffer in bytes.
62    #[inline]
63    fn size(&self) -> usize {
64        T::BlockSize::USIZE
65    }
66
67    #[inline]
68    fn set_pos_unchecked(&mut self, pos: usize) {
69        debug_assert!(pos < T::BlockSize::USIZE);
70        self.pos = pos as u8;
71    }
72
73    /// Return number of remaining bytes in the internal buffer.
74    #[inline]
75    fn remaining(&self) -> usize {
76        self.size() - self.get_pos()
77    }
78
79    fn check_remaining(&self, dlen: usize) -> Result<(), StreamCipherError> {
80        let rem_blocks = match self.core.remaining_blocks() {
81            Some(v) => v,
82            None => return Ok(()),
83        };
84
85        let bytes = if self.pos == 0 {
86            dlen
87        } else {
88            let rem = self.remaining();
89            if dlen > rem {
90                dlen - rem
91            } else {
92                return Ok(());
93            }
94        };
95        let bs = T::BlockSize::USIZE;
96        let blocks = if bytes % bs == 0 {
97            bytes / bs
98        } else {
99            bytes / bs + 1
100        };
101        if blocks > rem_blocks {
102            Err(StreamCipherError)
103        } else {
104            Ok(())
105        }
106    }
107}
108
109impl<T: StreamCipherCore> StreamCipher for StreamCipherCoreWrapper<T>
110where
111    T::BlockSize: IsLess<U256>,
112    Le<T::BlockSize, U256>: NonZero,
113{
114    #[inline]
115    fn try_apply_keystream_inout(
116        &mut self,
117        mut data: InOutBuf<'_, '_, u8>,
118    ) -> Result<(), StreamCipherError> {
119        self.check_remaining(data.len())?;
120
121        let pos = self.get_pos();
122        if pos != 0 {
123            let rem = &self.buffer[pos..];
124            let n = data.len();
125            if n < rem.len() {
126                data.xor_in2out(&rem[..n]);
127                self.set_pos_unchecked(pos + n);
128                return Ok(());
129            }
130            let (mut left, right) = data.split_at(rem.len());
131            data = right;
132            left.xor_in2out(rem);
133        }
134
135        let (blocks, mut leftover) = data.into_chunks();
136        self.core.apply_keystream_blocks_inout(blocks);
137
138        let n = leftover.len();
139        if n != 0 {
140            self.core.write_keystream_block(&mut self.buffer);
141            leftover.xor_in2out(&self.buffer[..n]);
142        }
143        self.set_pos_unchecked(n);
144
145        Ok(())
146    }
147}
148
149impl<T: StreamCipherSeekCore> StreamCipherSeek for StreamCipherCoreWrapper<T>
150where
151    T::BlockSize: IsLess<U256>,
152    Le<T::BlockSize, U256>: NonZero,
153{
154    fn try_current_pos<SN: SeekNum>(&self) -> Result<SN, OverflowError> {
155        let Self { core, pos, .. } = self;
156        SN::from_block_byte(core.get_block_pos(), *pos, T::BlockSize::U8)
157    }
158
159    fn try_seek<SN: SeekNum>(&mut self, new_pos: SN) -> Result<(), StreamCipherError> {
160        let Self { core, buffer, pos } = self;
161        let (block_pos, byte_pos) = new_pos.into_block_byte(T::BlockSize::U8)?;
162        core.set_block_pos(block_pos);
163        if byte_pos != 0 {
164            self.core.write_keystream_block(buffer);
165        }
166        *pos = byte_pos;
167        Ok(())
168    }
169}
170
171// Note: ideally we would only implement the InitInner trait and everything
172// else would be handled by blanket impls, but unfortunately it will
173// not work properly without mutually exclusive traits, see:
174// https://github.com/rust-lang/rfcs/issues/1053
175
176impl<T: KeySizeUser + BlockSizeUser> KeySizeUser for StreamCipherCoreWrapper<T>
177where
178    T::BlockSize: IsLess<U256>,
179    Le<T::BlockSize, U256>: NonZero,
180{
181    type KeySize = T::KeySize;
182}
183
184impl<T: IvSizeUser + BlockSizeUser> IvSizeUser for StreamCipherCoreWrapper<T>
185where
186    T::BlockSize: IsLess<U256>,
187    Le<T::BlockSize, U256>: NonZero,
188{
189    type IvSize = T::IvSize;
190}
191
192impl<T: KeyIvInit + BlockSizeUser> KeyIvInit for StreamCipherCoreWrapper<T>
193where
194    T::BlockSize: IsLess<U256>,
195    Le<T::BlockSize, U256>: NonZero,
196{
197    #[inline]
198    fn new(key: &Key<Self>, iv: &Iv<Self>) -> Self {
199        Self {
200            core: T::new(key, iv),
201            buffer: Default::default(),
202            pos: 0,
203        }
204    }
205}
206
207impl<T: KeyInit + BlockSizeUser> KeyInit for StreamCipherCoreWrapper<T>
208where
209    T::BlockSize: IsLess<U256>,
210    Le<T::BlockSize, U256>: NonZero,
211{
212    #[inline]
213    fn new(key: &Key<Self>) -> Self {
214        Self {
215            core: T::new(key),
216            buffer: Default::default(),
217            pos: 0,
218        }
219    }
220}
221
#[cfg(feature = "zeroize")]
impl<T: BlockSizeUser> Drop for StreamCipherCoreWrapper<T>
where
    T::BlockSize: IsLess<U256>,
    Le<T::BlockSize, U256>: NonZero,
{
    /// Wipe the wrapper-owned state when the value is dropped.
    ///
    /// Only the cursor and the keystream buffer are zeroized here; the core
    /// field is not touched by this implementation.
    fn drop(&mut self) {
        self.pos.zeroize();
        self.buffer.zeroize();
    }
}
234
/// Marker impl: provided only when the wrapped core is itself
/// `ZeroizeOnDrop`, since the wrapper's own `Drop` impl clears just the
/// buffer and the cursor.
#[cfg(feature = "zeroize")]
impl<T: BlockSizeUser + ZeroizeOnDrop> ZeroizeOnDrop for StreamCipherCoreWrapper<T>
where
    T::BlockSize: IsLess<U256>,
    Le<T::BlockSize, U256>: NonZero,
{
}