// der/reader/slice.rs

//! Slice reader.

3use crate::{BytesRef, Decode, Error, ErrorKind, Header, Length, Reader, Result, Tag};
4
5/// [`Reader`] which consumes an input byte slice.
6#[derive(Clone, Debug)]
7pub struct SliceReader<'a> {
8    /// Byte slice being decoded.
9    bytes: BytesRef<'a>,
10
11    /// Did the decoding operation fail?
12    failed: bool,
13
14    /// Position within the decoded slice.
15    position: Length,
16}
17
18impl<'a> SliceReader<'a> {
19    /// Create a new slice reader for the given byte slice.
20    pub fn new(bytes: &'a [u8]) -> Result<Self> {
21        Ok(Self {
22            bytes: BytesRef::new(bytes)?,
23            failed: false,
24            position: Length::ZERO,
25        })
26    }
27
28    /// Return an error with the given [`ErrorKind`], annotating it with
29    /// context about where the error occurred.
30    pub fn error(&mut self, kind: ErrorKind) -> Error {
31        self.failed = true;
32        kind.at(self.position)
33    }
34
35    /// Return an error for an invalid value with the given tag.
36    pub fn value_error(&mut self, tag: Tag) -> Error {
37        self.error(tag.value_error().kind())
38    }
39
40    /// Did the decoding operation fail due to an error?
41    pub fn is_failed(&self) -> bool {
42        self.failed
43    }
44
45    /// Obtain the remaining bytes in this slice reader from the current cursor
46    /// position.
47    fn remaining(&self) -> Result<&'a [u8]> {
48        if self.is_failed() {
49            Err(ErrorKind::Failed.at(self.position))
50        } else {
51            self.bytes
52                .as_slice()
53                .get(self.position.try_into()?..)
54                .ok_or_else(|| Error::incomplete(self.input_len()))
55        }
56    }
57}
58
59impl<'a> Reader<'a> for SliceReader<'a> {
60    fn input_len(&self) -> Length {
61        self.bytes.len()
62    }
63
64    fn peek_byte(&self) -> Option<u8> {
65        self.remaining()
66            .ok()
67            .and_then(|bytes| bytes.first().cloned())
68    }
69
70    fn peek_header(&self) -> Result<Header> {
71        Header::decode(&mut self.clone())
72    }
73
74    fn position(&self) -> Length {
75        self.position
76    }
77
78    fn read_slice(&mut self, len: Length) -> Result<&'a [u8]> {
79        if self.is_failed() {
80            return Err(self.error(ErrorKind::Failed));
81        }
82
83        match self.remaining()?.get(..len.try_into()?) {
84            Some(result) => {
85                self.position = (self.position + len)?;
86                Ok(result)
87            }
88            None => Err(self.error(ErrorKind::Incomplete {
89                expected_len: (self.position + len)?,
90                actual_len: self.input_len(),
91            })),
92        }
93    }
94
95    fn decode<T: Decode<'a>>(&mut self) -> Result<T> {
96        if self.is_failed() {
97            return Err(self.error(ErrorKind::Failed));
98        }
99
100        T::decode(self).map_err(|e| {
101            self.failed = true;
102            e.nested(self.position)
103        })
104    }
105
106    fn error(&mut self, kind: ErrorKind) -> Error {
107        self.failed = true;
108        kind.at(self.position)
109    }
110
111    fn finish<T>(self, value: T) -> Result<T> {
112        if self.is_failed() {
113            Err(ErrorKind::Failed.at(self.position))
114        } else if !self.is_finished() {
115            Err(ErrorKind::TrailingData {
116                decoded: self.position,
117                remaining: self.remaining_len(),
118            }
119            .at(self.position))
120        } else {
121            Ok(value)
122        }
123    }
124
125    fn remaining_len(&self) -> Length {
126        debug_assert!(self.position <= self.input_len());
127        self.input_len().saturating_sub(self.position)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::SliceReader;
    use crate::{Decode, ErrorKind, Length, Reader, Tag};
    use hex_literal::hex;

    /// `INTEGER 42` (tag `0x02`, length `0x01`, value `0x2A`) followed by one
    /// extra `0x00` byte that is not part of the encoding.
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    #[test]
    fn empty_message() {
        let mut reader = SliceReader::new(&[]).unwrap();
        let err = bool::decode(&mut reader).expect_err("empty input must not decode");

        assert_eq!(err.position(), Some(Length::ZERO));
        assert_eq!(
            err.kind(),
            ErrorKind::Incomplete {
                expected_len: Length::ONE,
                actual_len: Length::ZERO,
            }
        );
    }

    #[test]
    fn invalid_field_length() {
        const MSG_LEN: usize = 2;

        let truncated = &EXAMPLE_MSG[..MSG_LEN];
        let mut reader = SliceReader::new(truncated).unwrap();
        let err = i8::decode(&mut reader).expect_err("truncated input must not decode");

        assert_eq!(err.position(), Some(Length::from(2u8)));

        let expected = ErrorKind::Incomplete {
            expected_len: Length::try_from(MSG_LEN + 1).unwrap(),
            actual_len: Length::try_from(MSG_LEN).unwrap(),
        };
        assert_eq!(err.kind(), expected);
    }

    #[test]
    fn trailing_data() {
        let mut reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let value = i8::decode(&mut reader).unwrap();
        assert_eq!(value, 42i8);

        let err = reader
            .finish(value)
            .expect_err("the trailing byte must be rejected");
        assert_eq!(err.position(), Some(Length::from(3u8)));
        assert_eq!(
            err.kind(),
            ErrorKind::TrailingData {
                decoded: Length::from(3u8),
                remaining: Length::ONE,
            }
        );
    }

    #[test]
    fn peek_tag() {
        let reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        assert_eq!(reader.position(), Length::ZERO);

        let tag = reader.peek_tag().unwrap();
        assert_eq!(tag, Tag::Integer);

        // Peeking must not advance the reader.
        assert_eq!(reader.position(), Length::ZERO);
    }

    #[test]
    fn peek_header() {
        let reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        assert_eq!(reader.position(), Length::ZERO);

        let header = reader.peek_header().unwrap();
        assert_eq!(header.tag, Tag::Integer);
        assert_eq!(header.length, Length::ONE);

        // Peeking must not advance the reader.
        assert_eq!(reader.position(), Length::ZERO);
    }
}