base64ct/
encoder.rs

1//! Buffered Base64 encoder.
2
3use crate::{
4    Encoding,
5    Error::{self, InvalidLength},
6    LineEnding, MIN_LINE_WIDTH,
7};
8use core::{cmp, marker::PhantomData, str};
9
10#[cfg(feature = "std")]
11use std::io;
12
13#[cfg(doc)]
14use crate::{Base64, Base64Unpadded};
15
16/// Stateful Base64 encoder with support for buffered, incremental encoding.
17///
18/// The `E` type parameter can be any type which impls [`Encoding`] such as
19/// [`Base64`] or [`Base64Unpadded`].
20pub struct Encoder<'o, E: Encoding> {
21    /// Output buffer.
22    output: &'o mut [u8],
23
24    /// Cursor within the output buffer.
25    position: usize,
26
27    /// Block buffer used for non-block-aligned data.
28    block_buffer: BlockBuffer,
29
30    /// Configuration and state for line-wrapping the output at a specified
31    /// column.
32    line_wrapper: Option<LineWrapper>,
33
34    /// Phantom parameter for the Base64 encoding in use.
35    encoding: PhantomData<E>,
36}
37
38impl<'o, E: Encoding> Encoder<'o, E> {
39    /// Create a new encoder which writes output to the given byte slice.
40    ///
41    /// Output constructed using this method is not line-wrapped.
42    pub fn new(output: &'o mut [u8]) -> Result<Self, Error> {
43        if output.is_empty() {
44            return Err(InvalidLength);
45        }
46
47        Ok(Self {
48            output,
49            position: 0,
50            block_buffer: BlockBuffer::default(),
51            line_wrapper: None,
52            encoding: PhantomData,
53        })
54    }
55
56    /// Create a new encoder which writes line-wrapped output to the given byte
57    /// slice.
58    ///
59    /// Output will be wrapped at the specified interval, using the provided
60    /// line ending. Use [`LineEnding::default()`] to use the conventional line
61    /// ending for the target OS.
62    ///
63    /// Minimum allowed line width is 4.
64    pub fn new_wrapped(
65        output: &'o mut [u8],
66        width: usize,
67        ending: LineEnding,
68    ) -> Result<Self, Error> {
69        let mut encoder = Self::new(output)?;
70        encoder.line_wrapper = Some(LineWrapper::new(width, ending)?);
71        Ok(encoder)
72    }
73
74    /// Encode the provided buffer as Base64, writing it to the output buffer.
75    ///
76    /// # Returns
77    /// - `Ok(bytes)` if the expected amount of data was read
78    /// - `Err(Error::InvalidLength)` if there is insufficient space in the output buffer
79    pub fn encode(&mut self, mut input: &[u8]) -> Result<(), Error> {
80        // If there's data in the block buffer, fill it
81        if !self.block_buffer.is_empty() {
82            self.process_buffer(&mut input)?;
83        }
84
85        while !input.is_empty() {
86            // Attempt to encode a stride of block-aligned data
87            let in_blocks = input.len() / 3;
88            let out_blocks = self.remaining().len() / 4;
89            let mut blocks = cmp::min(in_blocks, out_blocks);
90
91            // When line wrapping, cap the block-aligned stride at near/at line length
92            if let Some(line_wrapper) = &self.line_wrapper {
93                line_wrapper.wrap_blocks(&mut blocks)?;
94            }
95
96            if blocks > 0 {
97                let len = blocks.checked_mul(3).ok_or(InvalidLength)?;
98                let (in_aligned, in_rem) = input.split_at(len);
99                input = in_rem;
100                self.perform_encode(in_aligned)?;
101            }
102
103            // If there's remaining non-aligned data, fill the block buffer
104            if !input.is_empty() {
105                self.process_buffer(&mut input)?;
106            }
107        }
108
109        Ok(())
110    }
111
112    /// Get the position inside of the output buffer where the write cursor
113    /// is currently located.
114    pub fn position(&self) -> usize {
115        self.position
116    }
117
118    /// Finish encoding data, returning the resulting Base64 as a `str`.
119    pub fn finish(self) -> Result<&'o str, Error> {
120        self.finish_with_remaining().map(|(base64, _)| base64)
121    }
122
123    /// Finish encoding data, returning the resulting Base64 as a `str`
124    /// along with the remaining space in the output buffer.
125    pub fn finish_with_remaining(mut self) -> Result<(&'o str, &'o mut [u8]), Error> {
126        if !self.block_buffer.is_empty() {
127            let buffer_len = self.block_buffer.position;
128            let block = self.block_buffer.bytes;
129            self.perform_encode(&block[..buffer_len])?;
130        }
131
132        let (base64, remaining) = self.output.split_at_mut(self.position);
133        Ok((str::from_utf8(base64)?, remaining))
134    }
135
136    /// Borrow the remaining data in the buffer.
137    fn remaining(&mut self) -> &mut [u8] {
138        &mut self.output[self.position..]
139    }
140
141    /// Fill the block buffer with data, consuming and encoding it when the
142    /// buffer is full.
143    fn process_buffer(&mut self, input: &mut &[u8]) -> Result<(), Error> {
144        self.block_buffer.fill(input)?;
145
146        if self.block_buffer.is_full() {
147            let block = self.block_buffer.take();
148            self.perform_encode(&block)?;
149        }
150
151        Ok(())
152    }
153
154    /// Perform Base64 encoding operation.
155    fn perform_encode(&mut self, input: &[u8]) -> Result<usize, Error> {
156        let mut len = E::encode(input, self.remaining())?.as_bytes().len();
157
158        // Insert newline characters into the output as needed
159        if let Some(line_wrapper) = &mut self.line_wrapper {
160            line_wrapper.insert_newlines(&mut self.output[self.position..], &mut len)?;
161        }
162
163        self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
164        Ok(len)
165    }
166}
167
#[cfg(feature = "std")]
impl<'o, E: Encoding> io::Write for Encoder<'o, E> {
    /// Base64-encode all of `buf` into the encoder's output slice.
    ///
    /// On success the full length of `buf` is reported (short writes are
    /// never signalled). On failure the error from [`Encoder::encode`] is
    /// converted into an [`io::Error`]; some of `buf` may already have been
    /// encoded by then.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.encode(buf) {
            Ok(()) => Ok(buf.len()),
            Err(err) => Err(err.into()),
        }
    }

    /// Always succeeds without doing anything.
    ///
    /// A partial block still held in the internal buffer is not written here.
    fn flush(&mut self) -> io::Result<()> {
        // TODO(tarcieri): return an error if there's still data remaining in the buffer?
        Ok(())
    }
}
180
181/// Base64 encode buffer for a 1-block output.
182///
183/// This handles a partial block of data, i.e. data which hasn't been
184#[derive(Clone, Default, Debug)]
185struct BlockBuffer {
186    /// 3 decoded bytes to be encoded to a 4-byte Base64-encoded input.
187    bytes: [u8; Self::SIZE],
188
189    /// Position within the buffer.
190    position: usize,
191}
192
193impl BlockBuffer {
194    /// Size of the buffer in bytes: 3-bytes of unencoded input which
195    /// Base64 encode to 4-bytes of output.
196    const SIZE: usize = 3;
197
198    /// Fill the remaining space in the buffer with the input data.
199    fn fill(&mut self, input: &mut &[u8]) -> Result<(), Error> {
200        let remaining = Self::SIZE.checked_sub(self.position).ok_or(InvalidLength)?;
201        let len = cmp::min(input.len(), remaining);
202        self.bytes[self.position..][..len].copy_from_slice(&input[..len]);
203        self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
204        *input = &input[len..];
205        Ok(())
206    }
207
208    /// Take the output buffer, resetting the position to 0.
209    fn take(&mut self) -> [u8; Self::SIZE] {
210        debug_assert!(self.is_full());
211        let result = self.bytes;
212        *self = Default::default();
213        result
214    }
215
216    /// Is the buffer empty?
217    fn is_empty(&self) -> bool {
218        self.position == 0
219    }
220
221    /// Is the buffer full?
222    fn is_full(&self) -> bool {
223        self.position == Self::SIZE
224    }
225}
226
227/// Helper for wrapping Base64 at a given line width.
228#[derive(Debug)]
229struct LineWrapper {
230    /// Number of bytes remaining in the current line.
231    remaining: usize,
232
233    /// Column at which Base64 should be wrapped.
234    width: usize,
235
236    /// Newline characters to use at the end of each line.
237    ending: LineEnding,
238}
239
240impl LineWrapper {
241    /// Create a new linewrapper.
242    fn new(width: usize, ending: LineEnding) -> Result<Self, Error> {
243        if width < MIN_LINE_WIDTH {
244            return Err(InvalidLength);
245        }
246
247        Ok(Self {
248            remaining: width,
249            width,
250            ending,
251        })
252    }
253
254    /// Wrap the number of blocks to encode near/at EOL.
255    fn wrap_blocks(&self, blocks: &mut usize) -> Result<(), Error> {
256        if blocks.checked_mul(4).ok_or(InvalidLength)? >= self.remaining {
257            *blocks = self.remaining / 4;
258        }
259
260        Ok(())
261    }
262
263    /// Insert newlines into the output buffer as needed.
264    fn insert_newlines(&mut self, mut buffer: &mut [u8], len: &mut usize) -> Result<(), Error> {
265        let mut buffer_len = *len;
266
267        if buffer_len <= self.remaining {
268            self.remaining = self
269                .remaining
270                .checked_sub(buffer_len)
271                .ok_or(InvalidLength)?;
272
273            return Ok(());
274        }
275
276        buffer = &mut buffer[self.remaining..];
277        buffer_len = buffer_len
278            .checked_sub(self.remaining)
279            .ok_or(InvalidLength)?;
280
281        // The `wrap_blocks` function should ensure the buffer is no larger than a Base64 block
282        debug_assert!(buffer_len <= 4, "buffer too long: {}", buffer_len);
283
284        // Ensure space in buffer to add newlines
285        let buffer_end = buffer_len
286            .checked_add(self.ending.len())
287            .ok_or(InvalidLength)?;
288
289        if buffer_end >= buffer.len() {
290            return Err(InvalidLength);
291        }
292
293        // Shift the buffer contents to make space for the line ending
294        for i in (0..buffer_len).rev() {
295            buffer[i.checked_add(self.ending.len()).ok_or(InvalidLength)?] = buffer[i];
296        }
297
298        buffer[..self.ending.len()].copy_from_slice(self.ending.as_bytes());
299        *len = (*len).checked_add(self.ending.len()).ok_or(InvalidLength)?;
300        self.remaining = self.width.checked_sub(buffer_len).ok_or(InvalidLength)?;
301
302        Ok(())
303    }
304}
305
#[cfg(test)]
mod tests {
    use crate::{alphabet::Alphabet, test_vectors::*, Base64, Base64Unpadded, Encoder, LineEnding};

    #[test]
    fn encode_padded() {
        encode_test::<Base64>(PADDED_BIN, PADDED_BASE64, None);
    }

    #[test]
    fn encode_unpadded() {
        encode_test::<Base64Unpadded>(UNPADDED_BIN, UNPADDED_BASE64, None);
    }

    #[test]
    fn encode_multiline_padded() {
        encode_test::<Base64>(MULTILINE_PADDED_BIN, MULTILINE_PADDED_BASE64, Some(70));
    }

    #[test]
    fn encode_multiline_unpadded() {
        encode_test::<Base64Unpadded>(MULTILINE_UNPADDED_BIN, MULTILINE_UNPADDED_BASE64, Some(70));
    }

    #[test]
    fn no_trailing_newline_when_aligned() {
        let mut buffer = [0u8; 64];
        let mut encoder = Encoder::<Base64>::new_wrapped(&mut buffer, 64, LineEnding::LF).unwrap();
        encoder.encode(&[0u8; 48]).unwrap();

        // Ensure no newline character is present in this case
        assert_eq!(
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
            encoder.finish().unwrap()
        );
    }

    /// Core functionality of an encoding test.
    ///
    /// Encodes `input` with every chunk size from 1 byte up to and including
    /// the whole input in a single call, checking that each run produces
    /// `expected`. `wrapped` selects an optional line width (LF endings).
    fn encode_test<V: Alphabet>(input: &[u8], expected: &str, wrapped: Option<usize>) {
        let mut buffer = [0u8; 1024];

        // Inclusive upper bound so the single-call (unchunked) path is tested too.
        for chunk_size in 1..=input.len() {
            let mut encoder = match wrapped {
                Some(line_width) => {
                    Encoder::<V>::new_wrapped(&mut buffer, line_width, LineEnding::LF)
                }
                None => Encoder::<V>::new(&mut buffer),
            }
            .unwrap();

            for chunk in input.chunks(chunk_size) {
                encoder.encode(chunk).unwrap();
            }

            assert_eq!(expected, encoder.finish().unwrap());
        }
    }
}
363}