1use crate::{
2 errors::{IntoArrayError, NotEqualError},
3 InOut,
4};
5use core::{marker::PhantomData, slice};
6use generic_array::{ArrayLength, GenericArray};
7
/// Custom slice type which references one immutable (input) slice and one
/// mutable (output) slice of equal length.
///
/// Invariants relied upon by the `unsafe` code operating on this type:
/// - `in_ptr` is valid for reads of `len` elements for the lifetime `'inp`;
/// - `out_ptr` is valid for reads and writes of `len` elements for the
///   lifetime `'out`;
/// - the input and output regions are either exactly the same memory
///   (in-place operation) or do not overlap at all.
pub struct InOutBuf<'inp, 'out, T> {
    /// Start of the input region.
    pub(crate) in_ptr: *const T,
    /// Start of the output region (equal to `in_ptr` for in-place buffers).
    pub(crate) out_ptr: *mut T,
    /// Number of `T` elements in each region.
    pub(crate) len: usize,
    /// Ties the raw pointers to the borrowed lifetimes: a shared borrow of the
    /// input for `'inp` and an exclusive borrow of the output for `'out`.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
17
18impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
19 #[inline(always)]
20 fn from(buf: &'a mut [T]) -> Self {
21 let p = buf.as_mut_ptr();
22 Self {
23 in_ptr: p,
24 out_ptr: p,
25 len: buf.len(),
26 _pd: PhantomData,
27 }
28 }
29}
30
31impl<'a, T> InOutBuf<'a, 'a, T> {
32 #[inline(always)]
34 pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
35 let p = val as *mut T;
36 Self {
37 in_ptr: p,
38 out_ptr: p,
39 len: 1,
40 _pd: PhantomData,
41 }
42 }
43}
44
45impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
46 type Item = InOut<'inp, 'out, T>;
47 type IntoIter = InOutBufIter<'inp, 'out, T>;
48
49 #[inline(always)]
50 fn into_iter(self) -> Self::IntoIter {
51 InOutBufIter { buf: self, pos: 0 }
52 }
53}
54
55impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
56 #[inline(always)]
58 pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
59 Self {
60 in_ptr: in_val as *const T,
61 out_ptr: out_val as *mut T,
62 len: 1,
63 _pd: PhantomData,
64 }
65 }
66
67 #[inline(always)]
71 pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
72 if in_buf.len() != out_buf.len() {
73 Err(NotEqualError)
74 } else {
75 Ok(Self {
76 in_ptr: in_buf.as_ptr(),
77 out_ptr: out_buf.as_mut_ptr(),
78 len: in_buf.len(),
79 _pd: Default::default(),
80 })
81 }
82 }
83
84 #[inline(always)]
86 pub fn len(&self) -> usize {
87 self.len
88 }
89
90 #[inline(always)]
92 pub fn is_empty(&self) -> bool {
93 self.len == 0
94 }
95
96 #[inline(always)]
101 pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
102 assert!(pos < self.len);
103 unsafe {
104 InOut {
105 in_ptr: self.in_ptr.add(pos),
106 out_ptr: self.out_ptr.add(pos),
107 _pd: PhantomData,
108 }
109 }
110 }
111
112 #[inline(always)]
114 pub fn get_in<'a>(&'a self) -> &'a [T] {
115 unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
116 }
117
118 #[inline(always)]
120 pub fn get_out<'a>(&'a mut self) -> &'a mut [T] {
121 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
122 }
123
124 #[inline(always)]
126 pub fn into_out(self) -> &'out mut [T] {
127 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
128 }
129
130 #[inline(always)]
132 pub fn into_raw(self) -> (*const T, *mut T) {
133 (self.in_ptr, self.out_ptr)
134 }
135
136 #[inline(always)]
138 pub fn reborrow<'a>(&'a mut self) -> InOutBuf<'a, 'a, T> {
139 Self {
140 in_ptr: self.in_ptr,
141 out_ptr: self.out_ptr,
142 len: self.len,
143 _pd: PhantomData,
144 }
145 }
146
147 #[inline(always)]
168 pub unsafe fn from_raw(
169 in_ptr: *const T,
170 out_ptr: *mut T,
171 len: usize,
172 ) -> InOutBuf<'inp, 'out, T> {
173 Self {
174 in_ptr,
175 out_ptr,
176 len,
177 _pd: PhantomData,
178 }
179 }
180
181 #[inline(always)]
191 pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
192 assert!(mid <= self.len);
193 let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
194 (
195 InOutBuf {
196 in_ptr: self.in_ptr,
197 out_ptr: self.out_ptr,
198 len: mid,
199 _pd: PhantomData,
200 },
201 InOutBuf {
202 in_ptr: tail_in_ptr,
203 out_ptr: tail_out_ptr,
204 len: self.len() - mid,
205 _pd: PhantomData,
206 },
207 )
208 }
209
210 #[inline(always)]
212 pub fn into_chunks<N: ArrayLength<T>>(
213 self,
214 ) -> (
215 InOutBuf<'inp, 'out, GenericArray<T, N>>,
216 InOutBuf<'inp, 'out, T>,
217 ) {
218 let chunks = self.len() / N::USIZE;
219 let tail_pos = N::USIZE * chunks;
220 let tail_len = self.len() - tail_pos;
221 unsafe {
222 let chunks = InOutBuf {
223 in_ptr: self.in_ptr as *const GenericArray<T, N>,
224 out_ptr: self.out_ptr as *mut GenericArray<T, N>,
225 len: chunks,
226 _pd: PhantomData,
227 };
228 let tail = InOutBuf {
229 in_ptr: self.in_ptr.add(tail_pos),
230 out_ptr: self.out_ptr.add(tail_pos),
231 len: tail_len,
232 _pd: PhantomData,
233 };
234 (chunks, tail)
235 }
236 }
237}
238
239impl<'inp, 'out> InOutBuf<'inp, 'out, u8> {
240 #[inline(always)]
246 #[allow(clippy::needless_range_loop)]
247 pub fn xor_in2out(&mut self, data: &[u8]) {
248 assert_eq!(self.len(), data.len());
249 unsafe {
250 for i in 0..data.len() {
251 let in_ptr = self.in_ptr.add(i);
252 let out_ptr = self.out_ptr.add(i);
253 *out_ptr = *in_ptr ^ data[i];
254 }
255 }
256 }
257}
258
259impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, GenericArray<T, N>>> for InOutBuf<'inp, 'out, T>
260where
261 N: ArrayLength<T>,
262{
263 type Error = IntoArrayError;
264
265 #[inline(always)]
266 fn try_into(self) -> Result<InOut<'inp, 'out, GenericArray<T, N>>, Self::Error> {
267 if self.len() == N::USIZE {
268 Ok(InOut {
269 in_ptr: self.in_ptr as *const _,
270 out_ptr: self.out_ptr as *mut _,
271 _pd: PhantomData,
272 })
273 } else {
274 Err(IntoArrayError)
275 }
276 }
277}
278
279pub struct InOutBufIter<'inp, 'out, T> {
281 buf: InOutBuf<'inp, 'out, T>,
282 pos: usize,
283}
284
285impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
286 type Item = InOut<'inp, 'out, T>;
287
288 #[inline(always)]
289 fn next(&mut self) -> Option<Self::Item> {
290 if self.buf.len() == self.pos {
291 return None;
292 }
293 let res = unsafe {
294 InOut {
295 in_ptr: self.buf.in_ptr.add(self.pos),
296 out_ptr: self.buf.out_ptr.add(self.pos),
297 _pd: PhantomData,
298 }
299 };
300 self.pos += 1;
301 Some(res)
302 }
303}