// inout/reserved.rs

1use crate::errors::OutIsTooSmallError;
2use core::{marker::PhantomData, slice};
3
4#[cfg(feature = "block-padding")]
5use crate::errors::PadError;
6#[cfg(feature = "block-padding")]
7use crate::{InOut, InOutBuf};
8#[cfg(feature = "block-padding")]
9use block_padding::{PadType, Padding};
10#[cfg(feature = "block-padding")]
11use generic_array::{ArrayLength, GenericArray};
12
/// Custom slice type which references one immutable (input) slice and one
/// mutable (output) slice. Input and output slices are either the same or
/// do not overlap. Length of the output slice is always equal or bigger than
/// length of the input slice.
///
/// When built from a single buffer (`from_mut_slice`), both pointers refer to
/// the start of that buffer: the input is its first `in_len` elements and the
/// output is the whole buffer of `out_len` elements.
pub struct InOutBufReserved<'inp, 'out, T> {
    /// Start of the input data; `in_len` elements are readable from here.
    in_ptr: *const T,
    /// Start of the output buffer; `out_len` elements are readable and
    /// writable from here.
    out_ptr: *mut T,
    /// Number of input elements.
    in_len: usize,
    /// Number of elements available in the output buffer (`>= in_len`).
    out_len: usize,
    /// Ties the raw pointers to a shared `'inp` borrow of the input and a
    /// mutable `'out` borrow of the output.
    _pd: PhantomData<(&'inp T, &'out mut T)>,
}
24
25impl<'a, T> InOutBufReserved<'a, 'a, T> {
26    /// Crate [`InOutBufReserved`] from a single mutable slice.
27    pub fn from_mut_slice(buf: &'a mut [T], msg_len: usize) -> Result<Self, OutIsTooSmallError> {
28        if msg_len > buf.len() {
29            return Err(OutIsTooSmallError);
30        }
31        let p = buf.as_mut_ptr();
32        let out_len = buf.len();
33        Ok(Self {
34            in_ptr: p,
35            out_ptr: p,
36            in_len: msg_len,
37            out_len,
38            _pd: PhantomData,
39        })
40    }
41
42    /// Create [`InOutBufReserved`] from raw input and output pointers.
43    ///
44    /// # Safety
45    /// Behavior is undefined if any of the following conditions are violated:
46    /// - `in_ptr` must point to a properly initialized value of type `T` and
47    /// must be valid for reads for `in_len * mem::size_of::<T>()` many bytes.
48    /// - `out_ptr` must point to a properly initialized value of type `T` and
49    /// must be valid for both reads and writes for `out_len * mem::size_of::<T>()`
50    /// many bytes.
51    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
52    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
53    /// them must not be accessed through any other pointer (not derived from
54    /// the return value) for the duration of lifetime 'a. Both read and write
55    /// accesses are forbidden.
56    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
57    /// `out_ptr` must not be accessed through any other pointer (not derived from
58    /// the return value) for the duration of lifetime 'a. Both read and write
59    /// accesses are forbidden. The memory referenced by `in_ptr` must not be
60    /// mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
61    /// - The total size `in_len * mem::size_of::<T>()` and
62    /// `out_len * mem::size_of::<T>()`  must be no larger than `isize::MAX`.
63    #[inline(always)]
64    pub unsafe fn from_raw(
65        in_ptr: *const T,
66        in_len: usize,
67        out_ptr: *mut T,
68        out_len: usize,
69    ) -> Self {
70        Self {
71            in_ptr,
72            out_ptr,
73            in_len,
74            out_len,
75            _pd: PhantomData,
76        }
77    }
78
79    /// Get raw input and output pointers.
80    #[inline(always)]
81    pub fn into_raw(self) -> (*const T, *mut T) {
82        (self.in_ptr, self.out_ptr)
83    }
84
85    /// Get input buffer length.
86    #[inline(always)]
87    pub fn get_in_len(&self) -> usize {
88        self.in_len
89    }
90
91    /// Get output buffer length.
92    #[inline(always)]
93    pub fn get_out_len(&self) -> usize {
94        self.in_len
95    }
96}
97
98impl<'inp, 'out, T> InOutBufReserved<'inp, 'out, T> {
99    /// Crate [`InOutBufReserved`] from two separate slices.
100    pub fn from_slices(
101        in_buf: &'inp [T],
102        out_buf: &'out mut [T],
103    ) -> Result<Self, OutIsTooSmallError> {
104        if in_buf.len() > out_buf.len() {
105            return Err(OutIsTooSmallError);
106        }
107        Ok(Self {
108            in_ptr: in_buf.as_ptr(),
109            out_ptr: out_buf.as_mut_ptr(),
110            in_len: in_buf.len(),
111            out_len: out_buf.len(),
112            _pd: PhantomData,
113        })
114    }
115
116    /// Get input slice.
117    #[inline(always)]
118    pub fn get_in<'a>(&'a self) -> &'a [T] {
119        unsafe { slice::from_raw_parts(self.in_ptr, self.in_len) }
120    }
121
122    /// Get output slice.
123    #[inline(always)]
124    pub fn get_out<'a>(&'a mut self) -> &'a mut [T] {
125        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.out_len) }
126    }
127}
128
impl<'inp, 'out> InOutBufReserved<'inp, 'out, u8> {
    /// Transform buffer into [`PaddedInOutBuf`] using padding algorithm `P`.
    ///
    /// The first `in_len / BS` complete blocks are exposed in place over the
    /// input and output pointers (no copy). The remaining `in_len % BS` bytes
    /// are copied into an owned tail block, which is then padded with `P`. Its
    /// output block is the `BS` bytes directly following the complete blocks
    /// in the output buffer.
    ///
    /// Whether a tail block is produced depends on `P::TYPE`:
    /// - `PadType::Reversible`: always, even when `in_len` is a multiple of `BS`
    ///   (the tail block then contains only padding).
    /// - `PadType::Ambiguous`: only when there are leftover bytes.
    /// - `PadType::NoPadding`: never; leftover bytes are an error.
    ///
    /// # Errors
    /// Returns [`PadError`] if `P::TYPE` is `PadType::NoPadding` and `in_len`
    /// is not a multiple of `BS`, or if a tail block is required but the output
    /// buffer is too short to hold it (`blocks_len * BS + BS > out_len`).
    #[cfg(feature = "block-padding")]
    #[cfg_attr(docsrs, doc(cfg(feature = "block-padding")))]
    #[inline(always)]
    pub fn into_padded_blocks<P, BS>(self) -> Result<PaddedInOutBuf<'inp, 'out, BS>, PadError>
    where
        P: Padding<BS>,
        BS: ArrayLength<u8>,
    {
        let bs = BS::USIZE;
        // Number of complete blocks and number of leftover bytes (`tail_len < bs`).
        let blocks_len = self.in_len / bs;
        let tail_len = self.in_len - bs * blocks_len;
        // Complete blocks are processed in place over both buffers.
        // NOTE(review): fitting `blocks_len` blocks into the output relies on
        // the type invariant `out_len >= in_len`. The safe constructors check
        // it, but it is only assumed for values built with `from_raw`.
        let blocks = unsafe {
            InOutBuf::from_raw(
                self.in_ptr as *const GenericArray<u8, BS>,
                self.out_ptr as *mut GenericArray<u8, BS>,
                blocks_len,
            )
        };
        // Owned scratch block that receives the leftover input bytes and the padding.
        let mut tail_in = GenericArray::<u8, BS>::default();
        let tail_out = match P::TYPE {
            // No leftover data and the padding does not demand an extra block.
            PadType::NoPadding | PadType::Ambiguous if tail_len == 0 => None,
            // `NoPadding` cannot represent a partial trailing block.
            PadType::NoPadding => return Err(PadError),
            // `Reversible` always gets a tail block (all padding if `tail_len == 0`);
            // `Ambiguous` only reaches this arm with `tail_len > 0`.
            PadType::Reversible | PadType::Ambiguous => {
                let blen = bs * blocks_len;
                let res_len = blen + bs;
                if res_len > self.out_len {
                    return Err(PadError);
                }
                // SAFETY: `in_ptr + blen..in_ptr + blen + tail_len`
                // is valid region for reads and `tail_len` is smaller than `BS`.
                // We have verified that `blen + bs <= out_len`, in other words,
                // `out_ptr + blen..out_ptr + blen + bs` is valid region
                // for writes.
                let out_block = unsafe {
                    core::ptr::copy_nonoverlapping(
                        self.in_ptr.add(blen),
                        tail_in.as_mut_ptr(),
                        tail_len,
                    );
                    &mut *(self.out_ptr.add(blen) as *mut GenericArray<u8, BS>)
                };
                // Fill `tail_in[tail_len..]` according to the padding scheme.
                P::pad(&mut tail_in, tail_len);
                Some(out_block)
            }
        };
        Ok(PaddedInOutBuf {
            blocks,
            tail_in,
            tail_out,
        })
    }
}
183
/// Variant of [`InOutBuf`] with optional padded tail block.
///
/// Produced by [`InOutBufReserved::into_padded_blocks`].
#[cfg(feature = "block-padding")]
#[cfg_attr(docsrs, doc(cfg(feature = "block-padding")))]
pub struct PaddedInOutBuf<'inp, 'out, BS: ArrayLength<u8>> {
    /// Complete input/output blocks, processed in place.
    blocks: InOutBuf<'inp, 'out, GenericArray<u8, BS>>,
    /// Owned copy of the trailing input bytes, already padded.
    tail_in: GenericArray<u8, BS>,
    /// Output block for the padded tail, located directly after `blocks` in
    /// the output buffer; `None` when no tail block was produced.
    tail_out: Option<&'out mut GenericArray<u8, BS>>,
}
192
#[cfg(feature = "block-padding")]
impl<'inp, 'out, BS: ArrayLength<u8>> PaddedInOutBuf<'inp, 'out, BS> {
    /// Reborrow the complete (unpadded) blocks as an [`InOutBuf`].
    #[inline(always)]
    pub fn get_blocks<'a>(&'a mut self) -> InOutBuf<'a, 'a, GenericArray<u8, BS>> {
        self.blocks.reborrow()
    }

    /// Pair the padded tail input block with its output block, if one was reserved.
    ///
    /// A tail block is always reserved for paddings whose `P::TYPE` is
    /// `PadType::Reversible`, so for them this always returns `Some`.
    #[inline(always)]
    pub fn get_tail_block<'a>(&'a mut self) -> Option<InOut<'a, 'a, GenericArray<u8, BS>>> {
        if let Some(out_block) = self.tail_out.as_deref_mut() {
            let pair = (&self.tail_in, out_block);
            Some(pair.into())
        } else {
            None
        }
    }

    /// Consume the buffer and return the output region covered by its blocks.
    ///
    /// The slice starts at the beginning of the output buffer and spans every
    /// complete block plus the tail block when one is present.
    #[inline(always)]
    pub fn into_out(self) -> &'out [u8] {
        // Zero or one extra block for the padded tail.
        let tail_blocks = usize::from(self.tail_out.is_some());
        let byte_len = BS::USIZE * (self.blocks.len() + tail_blocks);
        let (_, out_ptr) = self.blocks.into_raw();
        // SAFETY: `out_ptr` is the start of the output buffer. The complete
        // blocks fit by the `out_len >= in_len` invariant, and a tail block is
        // only created after `into_padded_blocks` checked `blen + bs <= out_len`,
        // so `byte_len` bytes are in bounds.
        unsafe { slice::from_raw_parts(out_ptr as *const u8, byte_len) }
    }
}