wlan_common/
buffer_reader.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::UnalignedView;
6use core::mem::size_of;
7use zerocopy::{
8    FromBytes, Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut, Unaligned,
9};
10
11/// Types that can be converted into a `BufferReader` over a `ByteSlice`.
12///
13/// Both `BufferReader` and `ByteSlice` types implement this trait and it can be used to accept
14/// both byte slices and readers.
15pub trait IntoBufferReader<B> {
16    fn into_buffer_reader(self) -> BufferReader<B>;
17}
18
19impl<B: SplitByteSlice> IntoBufferReader<B> for B {
20    fn into_buffer_reader(self) -> BufferReader<B> {
21        BufferReader::new(self)
22    }
23}
24
/// A cursor over a byte slice that hands out typed views (or copied values) of its contents.
///
/// `read*` methods consume bytes from the front and advance the cursor. `peek*` methods inspect
/// the front without advancing.
pub struct BufferReader<B> {
    /// Total number of bytes consumed so far by `read*` calls.
    bytes_read: usize,
    /// The not-yet-consumed tail of the input.
    ///
    /// It is wrapped in an `Option` only so that `read_bytes` can move it out in order to split
    /// it. It is `Some` at all times outside that method.
    buffer: Option<B>,
}
29
30impl<B: SplitByteSlice> BufferReader<B> {
31    pub fn new(bytes: B) -> Self {
32        BufferReader { buffer: Some(bytes), bytes_read: 0 }
33    }
34
35    pub fn read<T>(&mut self) -> Option<Ref<B, T>>
36    where
37        T: Unaligned + KnownLayout + Immutable + FromBytes,
38    {
39        self.read_bytes(size_of::<T>()).map(|bytes| Ref::from_bytes(bytes).unwrap())
40    }
41
42    pub fn read_unaligned<T>(&mut self) -> Option<UnalignedView<B, T>>
43    where
44        T: FromBytes + Immutable,
45    {
46        self.read_bytes(size_of::<T>()).map(|bytes| UnalignedView::from_bytes(bytes).unwrap())
47    }
48
49    pub fn peek<T>(&self) -> Option<Ref<&[u8], T>>
50    where
51        T: Unaligned + Immutable + KnownLayout + FromBytes,
52    {
53        self.peek_bytes(size_of::<T>()).map(|bytes| Ref::from_bytes(bytes).unwrap())
54    }
55
56    pub fn peek_unaligned<T>(&self) -> Option<UnalignedView<&[u8], T>>
57    where
58        T: FromBytes + Immutable,
59    {
60        self.peek_bytes(size_of::<T>()).map(|bytes| UnalignedView::from_bytes(bytes).unwrap())
61    }
62
63    pub fn read_array<T>(&mut self, num_elems: usize) -> Option<Ref<B, [T]>>
64    where
65        T: Unaligned + FromBytes + Immutable,
66    {
67        self.read_bytes(size_of::<T>() * num_elems).map(|bytes| Ref::from_bytes(bytes).unwrap())
68    }
69
70    pub fn peek_array<T>(&self, num_elems: usize) -> Option<Ref<&[u8], [T]>>
71    where
72        T: Unaligned + FromBytes + Immutable,
73    {
74        self.peek_bytes(size_of::<T>() * num_elems).map(|bytes| Ref::from_bytes(bytes).unwrap())
75    }
76
77    pub fn read_byte(&mut self) -> Option<u8> {
78        self.read_bytes(1).map(|bytes| bytes[0])
79    }
80
81    pub fn peek_byte(&mut self) -> Option<u8> {
82        self.peek_bytes(1).map(|bytes| bytes[0])
83    }
84
85    /// Useful for reading integers.
86    ///
87    /// Example:
88    /// ```
89    /// let mut reader = BufferReader::new(&vec![1, 2, 3]);
90    /// let val = reader.read_value::<u16>();
91    /// assert_eq!(Some(1 + 256 * 2), val);
92    /// ```
93    pub fn read_value<T>(&mut self) -> Option<T>
94    where
95        T: FromBytes + Immutable + Copy,
96    {
97        self.read_unaligned::<T>().map(|view| view.get())
98    }
99
100    pub fn peek_value<T>(&self) -> Option<T>
101    where
102        T: FromBytes + Immutable + Copy,
103    {
104        self.peek_unaligned::<T>().map(|view| view.get())
105    }
106
107    pub fn read_bytes(&mut self, len: usize) -> Option<B> {
108        if self.buffer.as_ref().unwrap().len() >= len {
109            let (head, tail) = self.buffer.take().unwrap().split_at(len).ok().unwrap();
110            self.buffer = Some(tail);
111            self.bytes_read += len;
112            Some(head)
113        } else {
114            None
115        }
116    }
117
118    pub fn peek_bytes(&self, len: usize) -> Option<&[u8]> {
119        let buffer = self.buffer.as_ref().unwrap();
120        if buffer.len() >= len {
121            Some(&buffer[0..len])
122        } else {
123            None
124        }
125    }
126
127    pub fn peek_remaining(&self) -> &[u8] {
128        &self.buffer.as_ref().unwrap()[..]
129    }
130
131    pub fn bytes_read(&self) -> usize {
132        self.bytes_read
133    }
134
135    pub fn bytes_remaining(&self) -> usize {
136        self.buffer.as_ref().unwrap().len()
137    }
138
139    pub fn into_remaining(self) -> B {
140        self.buffer.unwrap()
141    }
142}
143
144impl<B: SplitByteSliceMut> BufferReader<B> {
145    pub fn peek_mut<T>(&mut self) -> Option<Ref<&mut [u8], T>>
146    where
147        T: Unaligned + FromBytes + Immutable + KnownLayout,
148    {
149        self.peek_bytes_mut(size_of::<T>()).map(|bytes| Ref::from_bytes(bytes).unwrap())
150    }
151
152    pub fn peek_mut_unaligned<T>(&mut self) -> Option<UnalignedView<&mut [u8], T>>
153    where
154        T: FromBytes + KnownLayout + Immutable,
155    {
156        self.peek_bytes_mut(size_of::<T>()).map(|bytes| UnalignedView::from_bytes(bytes).unwrap())
157    }
158
159    pub fn peek_array_mut<T>(&mut self, num_elems: usize) -> Option<Ref<&mut [u8], [T]>>
160    where
161        T: Unaligned + FromBytes + Immutable,
162    {
163        self.peek_bytes_mut(size_of::<T>() * num_elems).map(|bytes| Ref::from_bytes(bytes).unwrap())
164    }
165
166    pub fn peek_bytes_mut(&mut self, len: usize) -> Option<&mut [u8]> {
167        let buffer = self.buffer.as_mut().unwrap();
168        if buffer.len() >= len {
169            Some(&mut buffer[0..len])
170        } else {
171            None
172        }
173    }
174
175    pub fn peek_remaining_mut(&mut self) -> &mut [u8] {
176        &mut self.buffer.as_mut().unwrap()[..]
177    }
178}
179
180impl<B: SplitByteSlice> IntoBufferReader<B> for BufferReader<B> {
181    fn into_buffer_reader(self) -> BufferReader<B> {
182        self
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use zerocopy::{Immutable, IntoBytes};

    /// A packed 3-byte struct whose `u16` field sits at offset 1.
    #[repr(C, packed)]
    #[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned)]
    struct Foo {
        x: u8,
        y: u16,
    }

    #[test]
    pub fn read() {
        let mut data = vec![1u8, 2, 3, 4, 5, 6, 7];
        let mut reader = BufferReader::new(&mut data[..]);
        let foo = reader.read::<Foo>().expect("expected a Foo to be read");
        assert_eq!(1, foo.x);
        // Copy out of the packed struct; references to misaligned fields are not allowed.
        let y = foo.y;
        assert_eq!(2 + 3 * 256, y); // assuming little endian
        assert_eq!(3, reader.bytes_read());
        assert_eq!(4, reader.bytes_remaining());

        let bytes = reader.read_bytes(2).expect("expected 2 bytes to be read");
        assert_eq!(&[4, 5], bytes);
        assert_eq!(5, reader.bytes_read());
        assert_eq!(2, reader.bytes_remaining());

        assert!(reader.read::<Foo>().is_none());

        let rest = reader.into_remaining();
        assert_eq!(&[6, 7], rest);
    }

    #[test]
    pub fn peek() {
        let mut data = vec![1u8, 2, 3, 4, 5, 6, 7];
        let mut reader = BufferReader::new(&mut data[..]);

        let foo = reader.peek::<Foo>().expect("expected a Foo (1)");
        assert_eq!(1, foo.x);

        let foo = reader.peek::<Foo>().expect("expected a Foo (2)");
        assert_eq!(1, foo.x);

        assert_eq!(0, reader.bytes_read());
        assert_eq!(7, reader.bytes_remaining());

        reader.read_bytes(5);

        let bytes = reader.peek_bytes(2).expect("expected a slice of 2 bytes");
        assert_eq!(&[6, 7], bytes);

        assert!(reader.peek_bytes(3).is_none());

        assert_eq!(&[6, 7], reader.peek_remaining());
    }

    #[test]
    pub fn peek_mut() {
        let mut data = vec![1u8, 2, 3, 4, 5, 6, 7];
        let mut reader = BufferReader::new(&mut data[..]);

        let foo = reader.peek_mut::<Foo>().expect("expected a Foo (1)");
        assert_eq!(1, foo.x);

        let foo = reader.peek_mut::<Foo>().expect("expected a Foo (2)");
        assert_eq!(1, foo.x);

        assert_eq!(0, reader.bytes_read());
        assert_eq!(7, reader.bytes_remaining());

        reader.read_bytes(5);

        let bytes = reader.peek_bytes_mut(2).expect("expected a slice of 2 bytes");
        assert_eq!(&[6, 7], bytes);

        assert!(reader.peek_bytes_mut(3).is_none());

        assert_eq!(&[6, 7], reader.peek_remaining_mut());
    }

    #[test]
    pub fn peek_and_read_value() {
        let mut data = vec![1u8, 2, 3, 4];
        let mut reader = BufferReader::new(&mut data[..]);

        assert_eq!(Some(1), reader.peek_byte());
        assert_eq!(0, reader.bytes_read());
        assert_eq!(4, reader.bytes_remaining());

        assert_eq!(Some(1), reader.read_byte());
        assert_eq!(1, reader.bytes_read());
        assert_eq!(3, reader.bytes_remaining());

        assert_eq!(Some(2 + 256 * 3), reader.peek_value::<u16>()); // assuming little endian
        assert_eq!(1, reader.bytes_read());
        assert_eq!(3, reader.bytes_remaining());

        assert_eq!(Some(2 + 256 * 3), reader.read_value::<u16>()); // assuming little endian
        assert_eq!(3, reader.bytes_read());
        assert_eq!(1, reader.bytes_remaining());

        assert_eq!(None, reader.peek_value::<u16>());
        assert_eq!(None, reader.read_value::<u16>());
        assert_eq!(3, reader.bytes_read());
        assert_eq!(1, reader.bytes_remaining());
    }

    #[test]
    pub fn peek_and_read_array() {
        let mut data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
        let mut reader = BufferReader::new(&mut data[..]);

        let arr = reader.peek_array::<Foo>(2).expect("expected peek_array() to return Some");
        assert_eq!(2, arr.len());
        assert_eq!(1, arr[0].x);
        assert_eq!(4, arr[1].x);

        assert_eq!(0, reader.bytes_read());
        assert_eq!(8, reader.bytes_remaining());

        let arr =
            reader.peek_array_mut::<Foo>(2).expect("expected peek_array_mut() to return Some");
        assert_eq!(2, arr.len());
        assert_eq!(1, arr[0].x);
        assert_eq!(4, arr[1].x);

        assert_eq!(0, reader.bytes_read());
        assert_eq!(8, reader.bytes_remaining());

        let arr = reader.read_array::<Foo>(2).expect("expected read_array() to return Some");
        assert_eq!(2, arr.len());
        assert_eq!(1, arr[0].x);
        assert_eq!(4, arr[1].x);

        assert_eq!(6, reader.bytes_read());
        assert_eq!(2, reader.bytes_remaining());

        assert!(reader.peek_array::<Foo>(1).is_none());
        assert!(reader.read_array::<Foo>(1).is_none());
        assert_eq!(6, reader.bytes_read());
        assert_eq!(2, reader.bytes_remaining());
    }

    #[test]
    pub fn peek_mut_and_modify() {
        let mut data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
        let mut reader = BufferReader::new(&mut data[..]);

        let mut foo = reader.peek_mut::<Foo>().expect("expected peek_mut() to return Some");
        foo.y = 0xaabb;

        let foo = reader.read::<Foo>().expect("expected read() to return Some");
        let y = foo.y;
        assert_eq!(0xaabb, y);
        assert_eq!(0xbb, data[1]);
        assert_eq!(0xaa, data[2]);
    }

    #[test]
    pub fn unaligned_access() {
        let mut data = vec![1u8, 2, 3, 4, 5, 6];
        let mut reader = BufferReader::new(&mut data[..]);

        reader.read_byte().expect("expected read_byte to return Some");

        let mut number = reader
            .peek_mut_unaligned::<u32>()
            .expect("expected peek_mut_unaligned to return Some");
        assert_eq!(0x05040302, number.get());
        number.set(0x0a090807);
        assert_eq!(1, reader.bytes_read());

        let number =
            reader.peek_unaligned::<u32>().expect("expected peek_unaligned to return Some");
        assert_eq!(0x0a090807, number.get());
        assert_eq!(1, reader.bytes_read());

        let number =
            reader.read_unaligned::<u32>().expect("expected read_unaligned to return Some");
        assert_eq!(0x0a090807, number.get());
        assert_eq!(5, reader.bytes_read());

        assert_eq!(&[1, 7, 8, 9, 10, 6], &data[..]);
    }

    #[test]
    pub fn empty_buffer() {
        let data: [u8; 0] = [];
        let mut reader = BufferReader::new(&data[..]);

        assert_eq!(None, reader.peek_byte());
        assert_eq!(None, reader.read_byte());
        assert!(reader.peek::<Foo>().is_none());
        assert!(reader.read::<Foo>().is_none());
        assert_eq!(None, reader.peek_value::<u16>());
        assert_eq!(None, reader.read_value::<u16>());
        assert!(reader.peek_bytes(1).is_none());
        assert!(reader.read_bytes(1).is_none());

        assert_eq!(0, reader.bytes_read());
        assert_eq!(0, reader.bytes_remaining());
        assert!(reader.peek_remaining().is_empty());
    }

    #[test]
    pub fn zero_length_reads() {
        let data = [1u8, 2, 3];
        let mut reader = BufferReader::new(&data[..]);

        let bytes = reader.read_bytes(0).expect("expected an empty slice");
        assert!(bytes.is_empty());

        let arr = reader.peek_array::<Foo>(0).expect("expected an empty peeked array");
        assert_eq!(0, arr.len());

        let arr = reader.read_array::<Foo>(0).expect("expected an empty read array");
        assert_eq!(0, arr.len());

        assert_eq!(0, reader.bytes_read());
        assert_eq!(3, reader.bytes_remaining());
    }

    #[test]
    pub fn into_buffer_reader() {
        let data = [1u8, 2, 3];

        // A byte slice becomes a fresh reader.
        let mut reader: BufferReader<&[u8]> = (&data[..]).into_buffer_reader();
        assert_eq!(0, reader.bytes_read());
        assert_eq!(Some(1), reader.read_byte());

        // A reader converts to itself, keeping its position.
        let mut reader: BufferReader<&[u8]> = reader.into_buffer_reader();
        assert_eq!(1, reader.bytes_read());
        assert_eq!(Some(2), reader.read_byte());
        assert_eq!(1, reader.bytes_remaining());
    }
}