1use crate::{
4 encoding,
5 line_ending::{CHAR_CR, CHAR_LF},
6 Encoding,
7 Error::{self, InvalidLength},
8 MIN_LINE_WIDTH,
9};
10use core::{cmp, marker::PhantomData};
11
12#[cfg(feature = "alloc")]
13use {alloc::vec::Vec, core::iter};
14
15#[cfg(feature = "std")]
16use std::io;
17
18#[cfg(doc)]
19use crate::{Base64, Base64Unpadded};
20
21#[derive(Clone)]
26pub struct Decoder<'i, E: Encoding> {
27 line: Line<'i>,
29
30 line_reader: LineReader<'i>,
32
33 remaining_len: usize,
35
36 block_buffer: BlockBuffer,
38
39 encoding: PhantomData<E>,
41}
42
43impl<'i, E: Encoding> Decoder<'i, E> {
44 pub fn new(input: &'i [u8]) -> Result<Self, Error> {
51 let line_reader = LineReader::new_unwrapped(input)?;
52 let remaining_len = line_reader.decoded_len::<E>()?;
53
54 Ok(Self {
55 line: Line::default(),
56 line_reader,
57 remaining_len,
58 block_buffer: BlockBuffer::default(),
59 encoding: PhantomData,
60 })
61 }
62
63 pub fn new_wrapped(input: &'i [u8], line_width: usize) -> Result<Self, Error> {
88 let line_reader = LineReader::new_wrapped(input, line_width)?;
89 let remaining_len = line_reader.decoded_len::<E>()?;
90
91 Ok(Self {
92 line: Line::default(),
93 line_reader,
94 remaining_len,
95 block_buffer: BlockBuffer::default(),
96 encoding: PhantomData,
97 })
98 }
99
100 pub fn decode<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8], Error> {
108 if self.is_finished() {
109 return Err(InvalidLength);
110 }
111
112 let mut out_pos = 0;
113
114 while out_pos < out.len() {
115 if !self.block_buffer.is_empty() {
117 let out_rem = out.len().checked_sub(out_pos).ok_or(InvalidLength)?;
118 let bytes = self.block_buffer.take(out_rem)?;
119 out[out_pos..][..bytes.len()].copy_from_slice(bytes);
120 out_pos = out_pos.checked_add(bytes.len()).ok_or(InvalidLength)?;
121 }
122
123 if self.line.is_empty() && !self.line_reader.is_empty() {
125 self.advance_line()?;
126 }
127
128 let in_blocks = self.line.len() / 4;
130 let out_rem = out.len().checked_sub(out_pos).ok_or(InvalidLength)?;
131 let out_blocks = out_rem / 3;
132 let blocks = cmp::min(in_blocks, out_blocks);
133 let in_aligned = self.line.take(blocks.checked_mul(4).ok_or(InvalidLength)?);
134
135 if !in_aligned.is_empty() {
136 let out_buf = &mut out[out_pos..][..blocks.checked_mul(3).ok_or(InvalidLength)?];
137 let decoded_len = self.perform_decode(in_aligned, out_buf)?.len();
138 out_pos = out_pos.checked_add(decoded_len).ok_or(InvalidLength)?;
139 }
140
141 if out_pos < out.len() {
142 if self.is_finished() {
143 return Err(InvalidLength);
146 } else {
147 self.fill_block_buffer()?;
152 }
153 }
154 }
155
156 self.remaining_len = self
157 .remaining_len
158 .checked_sub(out.len())
159 .ok_or(InvalidLength)?;
160
161 Ok(out)
162 }
163
164 #[cfg(feature = "alloc")]
169 pub fn decode_to_end<'o>(&mut self, buf: &'o mut Vec<u8>) -> Result<&'o [u8], Error> {
170 let start_len = buf.len();
171 let remaining_len = self.remaining_len();
172 let total_len = start_len.checked_add(remaining_len).ok_or(InvalidLength)?;
173
174 if total_len > buf.capacity() {
175 buf.reserve(total_len.checked_sub(buf.capacity()).ok_or(InvalidLength)?);
176 }
177
178 buf.extend(iter::repeat(0).take(remaining_len));
180 self.decode(&mut buf[start_len..])?;
181 Ok(&buf[start_len..])
182 }
183
184 pub fn remaining_len(&self) -> usize {
188 self.remaining_len
189 }
190
191 pub fn is_finished(&self) -> bool {
193 self.line.is_empty() && self.line_reader.is_empty() && self.block_buffer.is_empty()
194 }
195
196 fn fill_block_buffer(&mut self) -> Result<(), Error> {
198 let mut buf = [0u8; BlockBuffer::SIZE];
199
200 let decoded = if self.line.len() < 4 && !self.line_reader.is_empty() {
201 let mut tmp = [0u8; 4];
203
204 let line_end = self.line.take(4);
206 tmp[..line_end.len()].copy_from_slice(line_end);
207
208 self.advance_line()?;
210 let len = 4usize.checked_sub(line_end.len()).ok_or(InvalidLength)?;
211 let line_begin = self.line.take(len);
212 tmp[line_end.len()..][..line_begin.len()].copy_from_slice(line_begin);
213
214 let tmp_len = line_begin
215 .len()
216 .checked_add(line_end.len())
217 .ok_or(InvalidLength)?;
218
219 self.perform_decode(&tmp[..tmp_len], &mut buf)
220 } else {
221 let block = self.line.take(4);
222 self.perform_decode(block, &mut buf)
223 }?;
224
225 self.block_buffer.fill(decoded)
226 }
227
228 fn advance_line(&mut self) -> Result<(), Error> {
230 debug_assert!(self.line.is_empty(), "expected line buffer to be empty");
231
232 if let Some(line) = self.line_reader.next().transpose()? {
233 self.line = line;
234 Ok(())
235 } else {
236 Err(InvalidLength)
237 }
238 }
239
240 fn perform_decode<'o>(&self, src: &[u8], dst: &'o mut [u8]) -> Result<&'o [u8], Error> {
242 if self.is_finished() {
243 E::decode(src, dst)
244 } else {
245 E::Unpadded::decode(src, dst)
246 }
247 }
248}
249
#[cfg(feature = "std")]
impl<'i, E: Encoding> io::Read for Decoder<'i, E> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.is_finished() {
            // End of stream.
            return Ok(0);
        }

        // `decode` rejects requests longer than the remaining data, whereas `read` may
        // legitimately return fewer bytes than asked for, so clamp the request.
        let slice = match buf.get_mut(..self.remaining_len()) {
            Some(bytes) => bytes,
            None => buf,
        };

        self.decode(slice)?;
        Ok(slice.len())
    }

    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        if self.is_finished() {
            return Ok(0);
        }
        Ok(self.decode_to_end(buf)?.len())
    }

    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        // `Read::read_exact` must succeed on an empty buffer, even at end of stream;
        // `decode` alone would report `InvalidLength` once the decoder has finished.
        if buf.is_empty() {
            return Ok(());
        }

        self.decode(buf)?;
        Ok(())
    }
}
277
278#[derive(Clone, Default, Debug)]
283struct BlockBuffer {
284 decoded: [u8; Self::SIZE],
286
287 length: usize,
289
290 position: usize,
292}
293
294impl BlockBuffer {
295 const SIZE: usize = 3;
297
298 fn fill(&mut self, decoded_input: &[u8]) -> Result<(), Error> {
300 debug_assert!(self.is_empty());
301
302 if decoded_input.len() > Self::SIZE {
303 return Err(InvalidLength);
304 }
305
306 self.position = 0;
307 self.length = decoded_input.len();
308 self.decoded[..decoded_input.len()].copy_from_slice(decoded_input);
309 Ok(())
310 }
311
312 fn take(&mut self, mut nbytes: usize) -> Result<&[u8], Error> {
317 debug_assert!(self.position <= self.length);
318 let start_pos = self.position;
319 let remaining_len = self.length.checked_sub(start_pos).ok_or(InvalidLength)?;
320
321 if nbytes > remaining_len {
322 nbytes = remaining_len;
323 }
324
325 self.position = self.position.checked_add(nbytes).ok_or(InvalidLength)?;
326 Ok(&self.decoded[start_pos..][..nbytes])
327 }
328
329 fn is_empty(&self) -> bool {
331 self.position == self.length
332 }
333}
334
335#[derive(Clone, Debug)]
337pub struct Line<'i> {
338 remaining: &'i [u8],
340}
341
342impl<'i> Default for Line<'i> {
343 fn default() -> Self {
344 Self::new(&[])
345 }
346}
347
348impl<'i> Line<'i> {
349 fn new(bytes: &'i [u8]) -> Self {
351 Self { remaining: bytes }
352 }
353
354 fn take(&mut self, nbytes: usize) -> &'i [u8] {
356 let (bytes, rest) = if nbytes < self.remaining.len() {
357 self.remaining.split_at(nbytes)
358 } else {
359 (self.remaining, [].as_ref())
360 };
361
362 self.remaining = rest;
363 bytes
364 }
365
366 fn slice_tail(&self, nbytes: usize) -> Result<&'i [u8], Error> {
368 let offset = self.len().checked_sub(nbytes).ok_or(InvalidLength)?;
369 self.remaining.get(offset..).ok_or(InvalidLength)
370 }
371
372 fn len(&self) -> usize {
374 self.remaining.len()
375 }
376
377 fn is_empty(&self) -> bool {
379 self.len() == 0
380 }
381
382 fn trim_end(&self) -> Self {
384 Line::new(match self.remaining {
385 [line @ .., CHAR_CR, CHAR_LF] => line,
386 [line @ .., CHAR_CR] => line,
387 [line @ .., CHAR_LF] => line,
388 line => line,
389 })
390 }
391}
392
393#[derive(Clone)]
395struct LineReader<'i> {
396 remaining: &'i [u8],
398
399 line_width: Option<usize>,
401}
402
403impl<'i> LineReader<'i> {
404 fn new_unwrapped(bytes: &'i [u8]) -> Result<Self, Error> {
406 if bytes.is_empty() {
407 Err(InvalidLength)
408 } else {
409 Ok(Self {
410 remaining: bytes,
411 line_width: None,
412 })
413 }
414 }
415
416 fn new_wrapped(bytes: &'i [u8], line_width: usize) -> Result<Self, Error> {
418 if line_width < MIN_LINE_WIDTH {
419 return Err(InvalidLength);
420 }
421
422 let mut reader = Self::new_unwrapped(bytes)?;
423 reader.line_width = Some(line_width);
424 Ok(reader)
425 }
426
427 fn is_empty(&self) -> bool {
429 self.remaining.is_empty()
430 }
431
432 fn decoded_len<E: Encoding>(&self) -> Result<usize, Error> {
434 let mut buffer = [0u8; 4];
435 let mut lines = self.clone();
436 let mut line = match lines.next().transpose()? {
437 Some(l) => l,
438 None => return Ok(0),
439 };
440 let mut base64_len = 0usize;
441
442 loop {
443 base64_len = base64_len.checked_add(line.len()).ok_or(InvalidLength)?;
444
445 match lines.next().transpose()? {
446 Some(l) => {
447 buffer.copy_from_slice(line.slice_tail(4)?);
450
451 line = l
452 }
453
454 None => {
459 let base64_last_block_len = match base64_len % 4 {
461 0 => 4,
462 n => n,
463 };
464
465 let decoded_len = encoding::decoded_len(
467 base64_len
468 .checked_sub(base64_last_block_len)
469 .ok_or(InvalidLength)?,
470 );
471
472 let mut out = [0u8; 3];
474 let last_block_len = if line.len() < base64_last_block_len {
475 let buffered_part_len = base64_last_block_len
476 .checked_sub(line.len())
477 .ok_or(InvalidLength)?;
478
479 let offset = 4usize.checked_sub(buffered_part_len).ok_or(InvalidLength)?;
480
481 for i in 0..buffered_part_len {
482 buffer[i] = buffer[offset.checked_add(i).ok_or(InvalidLength)?];
483 }
484
485 buffer[buffered_part_len..][..line.len()].copy_from_slice(line.remaining);
486 let buffer_len = buffered_part_len
487 .checked_add(line.len())
488 .ok_or(InvalidLength)?;
489
490 E::decode(&buffer[..buffer_len], &mut out)?.len()
491 } else {
492 let last_block = line.slice_tail(base64_last_block_len)?;
493 E::decode(last_block, &mut out)?.len()
494 };
495
496 return decoded_len.checked_add(last_block_len).ok_or(InvalidLength);
497 }
498 }
499 }
500 }
501}
502
503impl<'i> Iterator for LineReader<'i> {
504 type Item = Result<Line<'i>, Error>;
505
506 fn next(&mut self) -> Option<Result<Line<'i>, Error>> {
507 if let Some(line_width) = self.line_width {
508 let rest = match self.remaining.get(line_width..) {
509 None | Some([]) => {
510 if self.remaining.is_empty() {
511 return None;
512 } else {
513 let line = Line::new(self.remaining).trim_end();
514 self.remaining = &[];
515 return Some(Ok(line));
516 }
517 }
518 Some([CHAR_CR, CHAR_LF, rest @ ..]) => rest,
519 Some([CHAR_CR, rest @ ..]) => rest,
520 Some([CHAR_LF, rest @ ..]) => rest,
521 _ => {
522 return Some(Err(Error::InvalidEncoding));
524 }
525 };
526
527 let line = Line::new(&self.remaining[..line_width]);
528 self.remaining = rest;
529 Some(Ok(line))
530 } else if !self.remaining.is_empty() {
531 let line = Line::new(self.remaining).trim_end();
532 self.remaining = b"";
533
534 if line.is_empty() {
535 None
536 } else {
537 Some(Ok(line))
538 }
539 } else {
540 None
541 }
542 }
543}
544
#[cfg(test)]
mod tests {
    use crate::{alphabet::Alphabet, test_vectors::*, Base64, Base64Unpadded, Decoder};

    #[cfg(feature = "std")]
    use {alloc::vec::Vec, std::io::Read};

    #[test]
    fn decode_padded() {
        decode_test(PADDED_BIN, || {
            Decoder::<Base64>::new(PADDED_BASE64.as_bytes()).unwrap()
        })
    }

    #[test]
    fn decode_unpadded() {
        decode_test(UNPADDED_BIN, || {
            Decoder::<Base64Unpadded>::new(UNPADDED_BASE64.as_bytes()).unwrap()
        })
    }

    #[test]
    fn decode_multiline_padded() {
        decode_test(MULTILINE_PADDED_BIN, || {
            Decoder::<Base64>::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap()
        })
    }

    #[test]
    fn decode_multiline_unpadded() {
        decode_test(MULTILINE_UNPADDED_BIN, || {
            Decoder::<Base64Unpadded>::new_wrapped(MULTILINE_UNPADDED_BASE64.as_bytes(), 70)
                .unwrap()
        })
    }

    #[cfg(feature = "std")]
    #[test]
    fn read_multiline_padded() {
        let mut decoder =
            Decoder::<Base64>::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap();

        let mut buf = Vec::new();
        let len = decoder.read_to_end(&mut buf).unwrap();

        assert_eq!(len, MULTILINE_PADDED_BIN.len());
        assert_eq!(buf.as_slice(), MULTILINE_PADDED_BIN);
    }

    #[cfg(feature = "std")]
    #[test]
    fn read_multiline_unpadded() {
        let mut decoder =
            Decoder::<Base64Unpadded>::new_wrapped(MULTILINE_UNPADDED_BASE64.as_bytes(), 70)
                .unwrap();

        let mut buf = Vec::new();
        let len = decoder.read_to_end(&mut buf).unwrap();

        assert_eq!(len, MULTILINE_UNPADDED_BIN.len());
        assert_eq!(buf.as_slice(), MULTILINE_UNPADDED_BIN);
    }

    /// Decode with every chunk size from a single byte up to the whole input in one
    /// call. After each chunk, check the output and the `remaining_len` bookkeeping.
    fn decode_test<'a, F, V>(expected: &[u8], f: F)
    where
        F: Fn() -> Decoder<'a, V>,
        V: Alphabet,
    {
        // Inclusive upper bound so that decoding everything in one call is covered too.
        for chunk_size in 1..=expected.len() {
            let mut decoder = f();
            let mut remaining_len = decoder.remaining_len();
            let mut buffer = [0u8; 1024];

            for chunk in expected.chunks(chunk_size) {
                assert!(!decoder.is_finished());
                let decoded = decoder.decode(&mut buffer[..chunk.len()]).unwrap();
                assert_eq!(chunk, decoded);

                remaining_len -= decoded.len();
                assert_eq!(remaining_len, decoder.remaining_len());
            }

            assert!(decoder.is_finished());
            assert_eq!(decoder.remaining_len(), 0);
        }
    }
}
618}