// der/reader/slice.rs
1use crate::{BytesRef, Decode, Error, ErrorKind, Header, Length, Reader, Result, Tag};
4
5#[derive(Clone, Debug)]
7pub struct SliceReader<'a> {
8 bytes: BytesRef<'a>,
10
11 failed: bool,
13
14 position: Length,
16}
17
18impl<'a> SliceReader<'a> {
19 pub fn new(bytes: &'a [u8]) -> Result<Self> {
21 Ok(Self {
22 bytes: BytesRef::new(bytes)?,
23 failed: false,
24 position: Length::ZERO,
25 })
26 }
27
28 pub fn error(&mut self, kind: ErrorKind) -> Error {
31 self.failed = true;
32 kind.at(self.position)
33 }
34
35 pub fn value_error(&mut self, tag: Tag) -> Error {
37 self.error(tag.value_error().kind())
38 }
39
40 pub fn is_failed(&self) -> bool {
42 self.failed
43 }
44
45 fn remaining(&self) -> Result<&'a [u8]> {
48 if self.is_failed() {
49 Err(ErrorKind::Failed.at(self.position))
50 } else {
51 self.bytes
52 .as_slice()
53 .get(self.position.try_into()?..)
54 .ok_or_else(|| Error::incomplete(self.input_len()))
55 }
56 }
57}
58
59impl<'a> Reader<'a> for SliceReader<'a> {
60 fn input_len(&self) -> Length {
61 self.bytes.len()
62 }
63
64 fn peek_byte(&self) -> Option<u8> {
65 self.remaining()
66 .ok()
67 .and_then(|bytes| bytes.first().cloned())
68 }
69
70 fn peek_header(&self) -> Result<Header> {
71 Header::decode(&mut self.clone())
72 }
73
74 fn position(&self) -> Length {
75 self.position
76 }
77
78 fn read_slice(&mut self, len: Length) -> Result<&'a [u8]> {
79 if self.is_failed() {
80 return Err(self.error(ErrorKind::Failed));
81 }
82
83 match self.remaining()?.get(..len.try_into()?) {
84 Some(result) => {
85 self.position = (self.position + len)?;
86 Ok(result)
87 }
88 None => Err(self.error(ErrorKind::Incomplete {
89 expected_len: (self.position + len)?,
90 actual_len: self.input_len(),
91 })),
92 }
93 }
94
95 fn decode<T: Decode<'a>>(&mut self) -> Result<T> {
96 if self.is_failed() {
97 return Err(self.error(ErrorKind::Failed));
98 }
99
100 T::decode(self).map_err(|e| {
101 self.failed = true;
102 e.nested(self.position)
103 })
104 }
105
106 fn error(&mut self, kind: ErrorKind) -> Error {
107 self.failed = true;
108 kind.at(self.position)
109 }
110
111 fn finish<T>(self, value: T) -> Result<T> {
112 if self.is_failed() {
113 Err(ErrorKind::Failed.at(self.position))
114 } else if !self.is_finished() {
115 Err(ErrorKind::TrailingData {
116 decoded: self.position,
117 remaining: self.remaining_len(),
118 }
119 .at(self.position))
120 } else {
121 Ok(value)
122 }
123 }
124
125 fn remaining_len(&self) -> Length {
126 debug_assert!(self.position <= self.input_len());
127 self.input_len().saturating_sub(self.position)
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::SliceReader;
    use crate::{Decode, ErrorKind, Length, Reader, Tag};
    use hex_literal::hex;

    // INTEGER (tag 0x02, length 0x01) holding 42 (0x2A), followed by one
    // trailing 0x00 byte.
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    #[test]
    fn empty_message() {
        let mut reader = SliceReader::new(&[]).unwrap();
        let err = bool::decode(&mut reader).unwrap_err();
        assert_eq!(Some(Length::ZERO), err.position());

        match err.kind() {
            ErrorKind::Incomplete {
                expected_len,
                actual_len,
            } => {
                assert_eq!(actual_len, 0u8.into());
                assert_eq!(expected_len, 1u8.into());
            }
            other => panic!("unexpected error kind: {:?}", other),
        }
    }

    #[test]
    fn invalid_field_length() {
        // Keep only the two header bytes so the 1-byte INTEGER body is missing.
        const MSG_LEN: usize = 2;

        let mut reader = SliceReader::new(&EXAMPLE_MSG[..MSG_LEN]).unwrap();
        let err = i8::decode(&mut reader).unwrap_err();
        assert_eq!(Some(Length::from(2u8)), err.position());

        match err.kind() {
            ErrorKind::Incomplete {
                expected_len,
                actual_len,
            } => {
                assert_eq!(actual_len, MSG_LEN.try_into().unwrap());
                assert_eq!(expected_len, (MSG_LEN + 1).try_into().unwrap());
            }
            other => panic!("unexpected error kind: {:?}", other),
        }
    }

    #[test]
    fn trailing_data() {
        let mut reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let x = i8::decode(&mut reader).unwrap();
        assert_eq!(42i8, x);

        // Three bytes were consumed; the final 0x00 is left over.
        let err = reader.finish(x).unwrap_err();
        assert_eq!(Some(Length::from(3u8)), err.position());

        assert_eq!(
            ErrorKind::TrailingData {
                decoded: 3u8.into(),
                remaining: 1u8.into()
            },
            err.kind()
        );
    }

    #[test]
    fn peek_tag() {
        let reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        assert_eq!(reader.position(), Length::ZERO);
        assert_eq!(reader.peek_tag().unwrap(), Tag::Integer);
        assert_eq!(reader.position(), Length::ZERO); // peeking does not advance
    }

    #[test]
    fn peek_header() {
        let reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        assert_eq!(reader.position(), Length::ZERO);

        let header = reader.peek_header().unwrap();
        assert_eq!(header.tag, Tag::Integer);
        assert_eq!(header.length, Length::ONE);
        assert_eq!(reader.position(), Length::ZERO); // peeking does not advance
    }
}