// inout/inout_buf.rs

1use crate::{
2    errors::{IntoArrayError, NotEqualError},
3    InOut,
4};
5use core::{marker::PhantomData, slice};
6use generic_array::{ArrayLength, GenericArray};
7
/// Slice-like view over one read-only input buffer and one writable output
/// buffer of the same length.
///
/// The two buffers either coincide (in-place processing) or do not overlap
/// at all.
pub struct InOutBuf<'inp, 'out, T> {
    /// First element of the input (read-only) buffer.
    pub(crate) in_ptr: *const T,
    /// First element of the output (writable) buffer; may equal `in_ptr`.
    pub(crate) out_ptr: *mut T,
    /// Number of `T` elements behind both `in_ptr` and `out_ptr`.
    pub(crate) len: usize,
    /// Binds the raw pointers to a shared borrow for `'inp` and a mutable
    /// borrow for `'out`.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
17
18impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
19    #[inline(always)]
20    fn from(buf: &'a mut [T]) -> Self {
21        let p = buf.as_mut_ptr();
22        Self {
23            in_ptr: p,
24            out_ptr: p,
25            len: buf.len(),
26            _pd: PhantomData,
27        }
28    }
29}
30
31impl<'a, T> InOutBuf<'a, 'a, T> {
32    /// Create `InOutBuf` from a single mutable reference.
33    #[inline(always)]
34    pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
35        let p = val as *mut T;
36        Self {
37            in_ptr: p,
38            out_ptr: p,
39            len: 1,
40            _pd: PhantomData,
41        }
42    }
43}
44
45impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
46    type Item = InOut<'inp, 'out, T>;
47    type IntoIter = InOutBufIter<'inp, 'out, T>;
48
49    #[inline(always)]
50    fn into_iter(self) -> Self::IntoIter {
51        InOutBufIter { buf: self, pos: 0 }
52    }
53}
54
55impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
56    /// Create `InOutBuf` from a pair of immutable and mutable references.
57    #[inline(always)]
58    pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
59        Self {
60            in_ptr: in_val as *const T,
61            out_ptr: out_val as *mut T,
62            len: 1,
63            _pd: PhantomData,
64        }
65    }
66
67    /// Create `InOutBuf` from immutable and mutable slices.
68    ///
69    /// Returns an error if length of slices is not equal to each other.
70    #[inline(always)]
71    pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
72        if in_buf.len() != out_buf.len() {
73            Err(NotEqualError)
74        } else {
75            Ok(Self {
76                in_ptr: in_buf.as_ptr(),
77                out_ptr: out_buf.as_mut_ptr(),
78                len: in_buf.len(),
79                _pd: Default::default(),
80            })
81        }
82    }
83
84    /// Get length of the inner buffers.
85    #[inline(always)]
86    pub fn len(&self) -> usize {
87        self.len
88    }
89
90    /// Returns `true` if the buffer has a length of 0.
91    #[inline(always)]
92    pub fn is_empty(&self) -> bool {
93        self.len == 0
94    }
95
96    /// Returns `InOut` for given position.
97    ///
98    /// # Panics
99    /// If `pos` greater or equal to buffer length.
100    #[inline(always)]
101    pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
102        assert!(pos < self.len);
103        unsafe {
104            InOut {
105                in_ptr: self.in_ptr.add(pos),
106                out_ptr: self.out_ptr.add(pos),
107                _pd: PhantomData,
108            }
109        }
110    }
111
112    /// Get input slice.
113    #[inline(always)]
114    pub fn get_in<'a>(&'a self) -> &'a [T] {
115        unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
116    }
117
118    /// Get output slice.
119    #[inline(always)]
120    pub fn get_out<'a>(&'a mut self) -> &'a mut [T] {
121        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
122    }
123
124    /// Consume self and return output slice with lifetime `'a`.
125    #[inline(always)]
126    pub fn into_out(self) -> &'out mut [T] {
127        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
128    }
129
130    /// Get raw input and output pointers.
131    #[inline(always)]
132    pub fn into_raw(self) -> (*const T, *mut T) {
133        (self.in_ptr, self.out_ptr)
134    }
135
136    /// Reborrow `self`.
137    #[inline(always)]
138    pub fn reborrow<'a>(&'a mut self) -> InOutBuf<'a, 'a, T> {
139        Self {
140            in_ptr: self.in_ptr,
141            out_ptr: self.out_ptr,
142            len: self.len,
143            _pd: PhantomData,
144        }
145    }
146
147    /// Create [`InOutBuf`] from raw input and output pointers.
148    ///
149    /// # Safety
150    /// Behavior is undefined if any of the following conditions are violated:
151    /// - `in_ptr` must point to a properly initialized value of type `T` and
152    /// must be valid for reads for `len * mem::size_of::<T>()` many bytes.
153    /// - `out_ptr` must point to a properly initialized value of type `T` and
154    /// must be valid for both reads and writes for `len * mem::size_of::<T>()`
155    /// many bytes.
156    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
157    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
158    /// them must not be accessed through any other pointer (not derived from
159    /// the return value) for the duration of lifetime 'a. Both read and write
160    /// accesses are forbidden.
161    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
162    /// `out_ptr` must not be accessed through any other pointer (not derived from
163    /// the return value) for the duration of lifetime 'a. Both read and write
164    /// accesses are forbidden. The memory referenced by `in_ptr` must not be
165    /// mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
166    /// - The total size `len * mem::size_of::<T>()`  must be no larger than `isize::MAX`.
167    #[inline(always)]
168    pub unsafe fn from_raw(
169        in_ptr: *const T,
170        out_ptr: *mut T,
171        len: usize,
172    ) -> InOutBuf<'inp, 'out, T> {
173        Self {
174            in_ptr,
175            out_ptr,
176            len,
177            _pd: PhantomData,
178        }
179    }
180
181    /// Divides one buffer into two at `mid` index.
182    ///
183    /// The first will contain all indices from `[0, mid)` (excluding
184    /// the index `mid` itself) and the second will contain all
185    /// indices from `[mid, len)` (excluding the index `len` itself).
186    ///
187    /// # Panics
188    ///
189    /// Panics if `mid > len`.
190    #[inline(always)]
191    pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
192        assert!(mid <= self.len);
193        let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
194        (
195            InOutBuf {
196                in_ptr: self.in_ptr,
197                out_ptr: self.out_ptr,
198                len: mid,
199                _pd: PhantomData,
200            },
201            InOutBuf {
202                in_ptr: tail_in_ptr,
203                out_ptr: tail_out_ptr,
204                len: self.len() - mid,
205                _pd: PhantomData,
206            },
207        )
208    }
209
210    /// Partition buffer into 2 parts: buffer of arrays and tail.
211    #[inline(always)]
212    pub fn into_chunks<N: ArrayLength<T>>(
213        self,
214    ) -> (
215        InOutBuf<'inp, 'out, GenericArray<T, N>>,
216        InOutBuf<'inp, 'out, T>,
217    ) {
218        let chunks = self.len() / N::USIZE;
219        let tail_pos = N::USIZE * chunks;
220        let tail_len = self.len() - tail_pos;
221        unsafe {
222            let chunks = InOutBuf {
223                in_ptr: self.in_ptr as *const GenericArray<T, N>,
224                out_ptr: self.out_ptr as *mut GenericArray<T, N>,
225                len: chunks,
226                _pd: PhantomData,
227            };
228            let tail = InOutBuf {
229                in_ptr: self.in_ptr.add(tail_pos),
230                out_ptr: self.out_ptr.add(tail_pos),
231                len: tail_len,
232                _pd: PhantomData,
233            };
234            (chunks, tail)
235        }
236    }
237}
238
239impl<'inp, 'out> InOutBuf<'inp, 'out, u8> {
240    /// XORs `data` with values behind the input slice and write
241    /// result to the output slice.
242    ///
243    /// # Panics
244    /// If `data` length is not equal to the buffer length.
245    #[inline(always)]
246    #[allow(clippy::needless_range_loop)]
247    pub fn xor_in2out(&mut self, data: &[u8]) {
248        assert_eq!(self.len(), data.len());
249        unsafe {
250            for i in 0..data.len() {
251                let in_ptr = self.in_ptr.add(i);
252                let out_ptr = self.out_ptr.add(i);
253                *out_ptr = *in_ptr ^ data[i];
254            }
255        }
256    }
257}
258
259impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, GenericArray<T, N>>> for InOutBuf<'inp, 'out, T>
260where
261    N: ArrayLength<T>,
262{
263    type Error = IntoArrayError;
264
265    #[inline(always)]
266    fn try_into(self) -> Result<InOut<'inp, 'out, GenericArray<T, N>>, Self::Error> {
267        if self.len() == N::USIZE {
268            Ok(InOut {
269                in_ptr: self.in_ptr as *const _,
270                out_ptr: self.out_ptr as *mut _,
271                _pd: PhantomData,
272            })
273        } else {
274            Err(IntoArrayError)
275        }
276    }
277}
278
279/// Iterator over [`InOutBuf`].
280pub struct InOutBufIter<'inp, 'out, T> {
281    buf: InOutBuf<'inp, 'out, T>,
282    pos: usize,
283}
284
285impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
286    type Item = InOut<'inp, 'out, T>;
287
288    #[inline(always)]
289    fn next(&mut self) -> Option<Self::Item> {
290        if self.buf.len() == self.pos {
291            return None;
292        }
293        let res = unsafe {
294            InOut {
295                in_ptr: self.buf.in_ptr.add(self.pos),
296                out_ptr: self.buf.out_ptr.add(self.pos),
297                _pd: PhantomData,
298            }
299        };
300        self.pos += 1;
301        Some(res)
302    }
303}