inout/
inout.rs

1use crate::InOutBuf;
2use core::{marker::PhantomData, ptr};
3use generic_array::{ArrayLength, GenericArray};
4
/// Custom pointer type which contains one immutable (input) and one mutable
/// (output) pointer, which are either equal or non-overlapping.
///
/// `'inp` is the lifetime of the input borrow and `'out` is the lifetime of
/// the output borrow. The raw pointers carry no lifetime themselves, so both
/// lifetimes are attached through the `_pd` marker.
pub struct InOut<'inp, 'out, T> {
    pub(crate) in_ptr: *const T,
    pub(crate) out_ptr: *mut T,
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}

impl<'inp, 'out, T> InOut<'inp, 'out, T> {
    /// Reborrow `self`.
    ///
    /// The result mutably borrows `self`, so `self` cannot be used again
    /// until the reborrowed value is dropped.
    #[inline(always)]
    pub fn reborrow<'a>(&'a mut self) -> InOut<'a, 'a, T> {
        InOut {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            _pd: PhantomData,
        }
    }

    /// Get immutable reference to the input value.
    #[inline(always)]
    pub fn get_in(&self) -> &T {
        // SAFETY: every constructor requires `in_ptr` to point to an initialized
        // `T` valid for reads during `'inp`. The returned reference borrows
        // `self` immutably, so `get_out` cannot create an aliasing `&mut`
        // while it is alive.
        unsafe { &*self.in_ptr }
    }

    /// Get mutable reference to the output value.
    #[inline(always)]
    pub fn get_out(&mut self) -> &mut T {
        // SAFETY: every constructor requires `out_ptr` to point to an initialized
        // `T` valid for reads and writes, not accessed through other pointers
        // during `'out`. The returned reference borrows `self` mutably, so no
        // other reference obtained from `self` can coexist with it.
        unsafe { &mut *self.out_ptr }
    }

    /// Convert `self` to a pair of raw input and output pointers.
    #[inline(always)]
    pub fn into_raw(self) -> (*const T, *mut T) {
        (self.in_ptr, self.out_ptr)
    }

    /// Create `InOut` from raw input and output pointers.
    ///
    /// # Safety
    /// Behavior is undefined if any of the following conditions are violated:
    /// - `in_ptr` must point to a properly initialized value of type `T` and
    ///   must be valid for reads.
    /// - `out_ptr` must point to a properly initialized value of type `T` and
    ///   must be valid for both reads and writes.
    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
    ///   them must not be accessed through any other pointer (not derived from
    ///   the return value) for the duration of lifetimes `'inp` and `'out`.
    ///   Both read and write accesses are forbidden.
    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
    ///   `out_ptr` must not be accessed through any other pointer (not derived
    ///   from the return value) for the duration of lifetime `'out`. Both read
    ///   and write accesses are forbidden. The memory referenced by `in_ptr`
    ///   must not be mutated for the duration of lifetime `'inp`, except inside
    ///   an `UnsafeCell`.
    #[inline(always)]
    pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
        InOut {
            in_ptr,
            out_ptr,
            _pd: PhantomData,
        }
    }
}

impl<'inp, 'out, T: Clone> InOut<'inp, 'out, T> {
    /// Clone input value and return it.
    #[inline(always)]
    pub fn clone_in(&self) -> T {
        // Delegate to `get_in` so the unsafe dereference lives in one place.
        self.get_in().clone()
    }
}

impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
    /// Build an `InOut` whose input and output are the same value,
    /// i.e. an in-place operation.
    #[inline(always)]
    fn from(val: &'a mut T) -> Self {
        let p = val as *mut T;
        InOut {
            in_ptr: p,
            out_ptr: p,
            _pd: PhantomData,
        }
    }
}

impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
    /// Build an `InOut` from separate input and output references. The borrow
    /// rules guarantee that a `&T` and a `&mut T` held at the same time cannot
    /// overlap.
    #[inline(always)]
    fn from((in_val, out_val): (&'inp T, &'out mut T)) -> Self {
        InOut {
            in_ptr: in_val as *const T,
            out_ptr: out_val as *mut T,
            _pd: PhantomData,
        }
    }
}
100
101impl<'inp, 'out, T, N: ArrayLength<T>> InOut<'inp, 'out, GenericArray<T, N>> {
102    /// Returns `InOut` for the given position.
103    ///
104    /// # Panics
105    /// If `pos` greater or equal to array length.
106    #[inline(always)]
107    pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
108        assert!(pos < N::USIZE);
109        unsafe {
110            InOut {
111                in_ptr: (self.in_ptr as *const T).add(pos),
112                out_ptr: (self.out_ptr as *mut T).add(pos),
113                _pd: PhantomData,
114            }
115        }
116    }
117
118    /// Convert `InOut` array to `InOutBuf`.
119    #[inline(always)]
120    pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
121        InOutBuf {
122            in_ptr: self.in_ptr as *const T,
123            out_ptr: self.out_ptr as *mut T,
124            len: N::USIZE,
125            _pd: PhantomData,
126        }
127    }
128}
129
130impl<'inp, 'out, N: ArrayLength<u8>> InOut<'inp, 'out, GenericArray<u8, N>> {
131    /// XOR `data` with values behind the input slice and write
132    /// result to the output slice.
133    ///
134    /// # Panics
135    /// If `data` length is not equal to the buffer length.
136    #[inline(always)]
137    #[allow(clippy::needless_range_loop)]
138    pub fn xor_in2out(&mut self, data: &GenericArray<u8, N>) {
139        unsafe {
140            let input = ptr::read(self.in_ptr);
141            let mut temp = GenericArray::<u8, N>::default();
142            for i in 0..N::USIZE {
143                temp[i] = input[i] ^ data[i];
144            }
145            ptr::write(self.out_ptr, temp);
146        }
147    }
148}
149
150impl<'inp, 'out, N, M> InOut<'inp, 'out, GenericArray<GenericArray<u8, N>, M>>
151where
152    N: ArrayLength<u8>,
153    M: ArrayLength<GenericArray<u8, N>>,
154{
155    /// XOR `data` with values behind the input slice and write
156    /// result to the output slice.
157    ///
158    /// # Panics
159    /// If `data` length is not equal to the buffer length.
160    #[inline(always)]
161    #[allow(clippy::needless_range_loop)]
162    pub fn xor_in2out(&mut self, data: &GenericArray<GenericArray<u8, N>, M>) {
163        unsafe {
164            let input = ptr::read(self.in_ptr);
165            let mut temp = GenericArray::<GenericArray<u8, N>, M>::default();
166            for i in 0..M::USIZE {
167                for j in 0..N::USIZE {
168                    temp[i][j] = input[i][j] ^ data[i][j];
169                }
170            }
171            ptr::write(self.out_ptr, temp);
172        }
173    }
174}