1use crate::{
4 Encoding,
5 Error::{self, InvalidLength},
6 LineEnding, MIN_LINE_WIDTH,
7};
8use core::{cmp, marker::PhantomData, str};
9
10#[cfg(feature = "std")]
11use std::io;
12
13#[cfg(doc)]
14use crate::{Base64, Base64Unpadded};
15
16pub struct Encoder<'o, E: Encoding> {
21 output: &'o mut [u8],
23
24 position: usize,
26
27 block_buffer: BlockBuffer,
29
30 line_wrapper: Option<LineWrapper>,
33
34 encoding: PhantomData<E>,
36}
37
38impl<'o, E: Encoding> Encoder<'o, E> {
39 pub fn new(output: &'o mut [u8]) -> Result<Self, Error> {
43 if output.is_empty() {
44 return Err(InvalidLength);
45 }
46
47 Ok(Self {
48 output,
49 position: 0,
50 block_buffer: BlockBuffer::default(),
51 line_wrapper: None,
52 encoding: PhantomData,
53 })
54 }
55
56 pub fn new_wrapped(
65 output: &'o mut [u8],
66 width: usize,
67 ending: LineEnding,
68 ) -> Result<Self, Error> {
69 let mut encoder = Self::new(output)?;
70 encoder.line_wrapper = Some(LineWrapper::new(width, ending)?);
71 Ok(encoder)
72 }
73
74 pub fn encode(&mut self, mut input: &[u8]) -> Result<(), Error> {
80 if !self.block_buffer.is_empty() {
82 self.process_buffer(&mut input)?;
83 }
84
85 while !input.is_empty() {
86 let in_blocks = input.len() / 3;
88 let out_blocks = self.remaining().len() / 4;
89 let mut blocks = cmp::min(in_blocks, out_blocks);
90
91 if let Some(line_wrapper) = &self.line_wrapper {
93 line_wrapper.wrap_blocks(&mut blocks)?;
94 }
95
96 if blocks > 0 {
97 let len = blocks.checked_mul(3).ok_or(InvalidLength)?;
98 let (in_aligned, in_rem) = input.split_at(len);
99 input = in_rem;
100 self.perform_encode(in_aligned)?;
101 }
102
103 if !input.is_empty() {
105 self.process_buffer(&mut input)?;
106 }
107 }
108
109 Ok(())
110 }
111
112 pub fn position(&self) -> usize {
115 self.position
116 }
117
118 pub fn finish(self) -> Result<&'o str, Error> {
120 self.finish_with_remaining().map(|(base64, _)| base64)
121 }
122
123 pub fn finish_with_remaining(mut self) -> Result<(&'o str, &'o mut [u8]), Error> {
126 if !self.block_buffer.is_empty() {
127 let buffer_len = self.block_buffer.position;
128 let block = self.block_buffer.bytes;
129 self.perform_encode(&block[..buffer_len])?;
130 }
131
132 let (base64, remaining) = self.output.split_at_mut(self.position);
133 Ok((str::from_utf8(base64)?, remaining))
134 }
135
136 fn remaining(&mut self) -> &mut [u8] {
138 &mut self.output[self.position..]
139 }
140
141 fn process_buffer(&mut self, input: &mut &[u8]) -> Result<(), Error> {
144 self.block_buffer.fill(input)?;
145
146 if self.block_buffer.is_full() {
147 let block = self.block_buffer.take();
148 self.perform_encode(&block)?;
149 }
150
151 Ok(())
152 }
153
154 fn perform_encode(&mut self, input: &[u8]) -> Result<usize, Error> {
156 let mut len = E::encode(input, self.remaining())?.as_bytes().len();
157
158 if let Some(line_wrapper) = &mut self.line_wrapper {
160 line_wrapper.insert_newlines(&mut self.output[self.position..], &mut len)?;
161 }
162
163 self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
164 Ok(len)
165 }
166}
167
#[cfg(feature = "std")]
impl<'o, E: Encoding> io::Write for Encoder<'o, E> {
    /// Encode all of `buf`. On success the whole slice is reported as consumed.
    ///
    /// Errors from [`Encoder::encode`] are converted into [`io::Error`].
    /// NOTE(review): `encode` may consume part of `buf` before failing, so an
    /// error here does not guarantee that nothing was written. Confirm this is
    /// acceptable against the `io::Write::write` contract.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.encode(buf)
            .map(|()| buf.len())
            .map_err(io::Error::from)
    }

    /// No-op. Encoded output is written directly into the caller's slice.
    /// A buffered partial input block is only emitted by `finish` or
    /// `finish_with_remaining`, not by flushing.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
180
181#[derive(Clone, Default, Debug)]
185struct BlockBuffer {
186 bytes: [u8; Self::SIZE],
188
189 position: usize,
191}
192
193impl BlockBuffer {
194 const SIZE: usize = 3;
197
198 fn fill(&mut self, input: &mut &[u8]) -> Result<(), Error> {
200 let remaining = Self::SIZE.checked_sub(self.position).ok_or(InvalidLength)?;
201 let len = cmp::min(input.len(), remaining);
202 self.bytes[self.position..][..len].copy_from_slice(&input[..len]);
203 self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
204 *input = &input[len..];
205 Ok(())
206 }
207
208 fn take(&mut self) -> [u8; Self::SIZE] {
210 debug_assert!(self.is_full());
211 let result = self.bytes;
212 *self = Default::default();
213 result
214 }
215
216 fn is_empty(&self) -> bool {
218 self.position == 0
219 }
220
221 fn is_full(&self) -> bool {
223 self.position == Self::SIZE
224 }
225}
226
227#[derive(Debug)]
229struct LineWrapper {
230 remaining: usize,
232
233 width: usize,
235
236 ending: LineEnding,
238}
239
240impl LineWrapper {
241 fn new(width: usize, ending: LineEnding) -> Result<Self, Error> {
243 if width < MIN_LINE_WIDTH {
244 return Err(InvalidLength);
245 }
246
247 Ok(Self {
248 remaining: width,
249 width,
250 ending,
251 })
252 }
253
254 fn wrap_blocks(&self, blocks: &mut usize) -> Result<(), Error> {
256 if blocks.checked_mul(4).ok_or(InvalidLength)? >= self.remaining {
257 *blocks = self.remaining / 4;
258 }
259
260 Ok(())
261 }
262
263 fn insert_newlines(&mut self, mut buffer: &mut [u8], len: &mut usize) -> Result<(), Error> {
265 let mut buffer_len = *len;
266
267 if buffer_len <= self.remaining {
268 self.remaining = self
269 .remaining
270 .checked_sub(buffer_len)
271 .ok_or(InvalidLength)?;
272
273 return Ok(());
274 }
275
276 buffer = &mut buffer[self.remaining..];
277 buffer_len = buffer_len
278 .checked_sub(self.remaining)
279 .ok_or(InvalidLength)?;
280
281 debug_assert!(buffer_len <= 4, "buffer too long: {}", buffer_len);
283
284 let buffer_end = buffer_len
286 .checked_add(self.ending.len())
287 .ok_or(InvalidLength)?;
288
289 if buffer_end >= buffer.len() {
290 return Err(InvalidLength);
291 }
292
293 for i in (0..buffer_len).rev() {
295 buffer[i.checked_add(self.ending.len()).ok_or(InvalidLength)?] = buffer[i];
296 }
297
298 buffer[..self.ending.len()].copy_from_slice(self.ending.as_bytes());
299 *len = (*len).checked_add(self.ending.len()).ok_or(InvalidLength)?;
300 self.remaining = self.width.checked_sub(buffer_len).ok_or(InvalidLength)?;
301
302 Ok(())
303 }
304}
305
#[cfg(test)]
mod tests {
    use crate::{alphabet::Alphabet, test_vectors::*, Base64, Base64Unpadded, Encoder, LineEnding};

    #[test]
    fn encode_padded() {
        encode_test::<Base64>(PADDED_BIN, PADDED_BASE64, None);
    }

    #[test]
    fn encode_unpadded() {
        encode_test::<Base64Unpadded>(UNPADDED_BIN, UNPADDED_BASE64, None);
    }

    #[test]
    fn encode_multiline_padded() {
        encode_test::<Base64>(MULTILINE_PADDED_BIN, MULTILINE_PADDED_BASE64, Some(70));
    }

    #[test]
    fn encode_multiline_unpadded() {
        encode_test::<Base64Unpadded>(MULTILINE_UNPADDED_BIN, MULTILINE_UNPADDED_BASE64, Some(70));
    }

    /// A line that ends exactly at the end of the data must not get a trailing
    /// line ending.
    #[test]
    fn no_trailing_newline_when_aligned() {
        let mut buffer = [0u8; 64];
        let mut encoder = Encoder::<Base64>::new_wrapped(&mut buffer, 64, LineEnding::LF).unwrap();
        encoder.encode(&[0u8; 48]).unwrap();

        assert_eq!(
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
            encoder.finish().unwrap()
        );
    }

    /// Encode `input` with every chunk size from a single byte up to the whole
    /// input in one call, and check each result against `expected`.
    fn encode_test<V: Alphabet>(input: &[u8], expected: &str, wrapped: Option<usize>) {
        let mut buffer = [0u8; 1024];

        // Inclusive upper bound so the single-call (one chunk) path is covered too.
        for chunk_size in 1..=input.len() {
            let mut encoder = match wrapped {
                Some(line_width) => {
                    Encoder::<V>::new_wrapped(&mut buffer, line_width, LineEnding::LF)
                }
                None => Encoder::<V>::new(&mut buffer),
            }
            .unwrap();

            for chunk in input.chunks(chunk_size) {
                encoder.encode(chunk).unwrap();
            }

            assert_eq!(expected, encoder.finish().unwrap());
        }
    }
}