// base64/write/encoder_string_writer.rs

1use super::encoder::EncoderWriter;
2use crate::engine::Engine;
3use std::io;
4
5/// A `Write` implementation that base64-encodes data using the provided config and accumulates the
6/// resulting base64 utf8 `&str` in a [StrConsumer] implementation (typically `String`), which is
7/// then exposed via `into_inner()`.
8///
9/// # Examples
10///
11/// Buffer base64 in a new String:
12///
13/// ```
14/// use std::io::Write;
15/// use base64::engine::general_purpose;
16///
17/// let mut enc = base64::write::EncoderStringWriter::new(&general_purpose::STANDARD);
18///
19/// enc.write_all(b"asdf").unwrap();
20///
21/// // get the resulting String
22/// let b64_string = enc.into_inner();
23///
24/// assert_eq!("YXNkZg==", &b64_string);
25/// ```
26///
27/// Or, append to an existing `String`, which implements `StrConsumer`:
28///
29/// ```
30/// use std::io::Write;
31/// use base64::engine::general_purpose;
32///
33/// let mut buf = String::from("base64: ");
34///
35/// let mut enc = base64::write::EncoderStringWriter::from_consumer(
36///     &mut buf,
37///     &general_purpose::STANDARD);
38///
39/// enc.write_all(b"asdf").unwrap();
40///
41/// // release the &mut reference on buf
42/// let _ = enc.into_inner();
43///
44/// assert_eq!("base64: YXNkZg==", &buf);
45/// ```
46///
47/// # Performance
48///
49/// Because it has to validate that the base64 is UTF-8, it is about 80% as fast as writing plain
50/// bytes to a `io::Write`.
51pub struct EncoderStringWriter<'e, E: Engine, S: StrConsumer> {
52    encoder: EncoderWriter<'e, E, Utf8SingleCodeUnitWriter<S>>,
53}
54
55impl<'e, E: Engine, S: StrConsumer> EncoderStringWriter<'e, E, S> {
56    /// Create a EncoderStringWriter that will append to the provided `StrConsumer`.
57    pub fn from_consumer(str_consumer: S, engine: &'e E) -> Self {
58        EncoderStringWriter {
59            encoder: EncoderWriter::new(Utf8SingleCodeUnitWriter { str_consumer }, engine),
60        }
61    }
62
63    /// Encode all remaining buffered data, including any trailing incomplete input triples and
64    /// associated padding.
65    ///
66    /// Returns the base64-encoded form of the accumulated written data.
67    pub fn into_inner(mut self) -> S {
68        self.encoder
69            .finish()
70            .expect("Writing to a consumer should never fail")
71            .str_consumer
72    }
73}
74
75impl<'e, E: Engine> EncoderStringWriter<'e, E, String> {
76    /// Create a EncoderStringWriter that will encode into a new `String` with the provided config.
77    pub fn new(engine: &'e E) -> Self {
78        EncoderStringWriter::from_consumer(String::new(), engine)
79    }
80}
81
82impl<'e, E: Engine, S: StrConsumer> io::Write for EncoderStringWriter<'e, E, S> {
83    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
84        self.encoder.write(buf)
85    }
86
87    fn flush(&mut self) -> io::Result<()> {
88        self.encoder.flush()
89    }
90}
91
/// A sink for the `str` chunks produced by base64 encoding.
pub trait StrConsumer {
    /// Consume `buf`, a chunk of base64-encoded text.
    fn consume(&mut self, buf: &str);
}
97
98/// As for io::Write, `StrConsumer` is implemented automatically for `&mut S`.
99impl<S: StrConsumer + ?Sized> StrConsumer for &mut S {
100    fn consume(&mut self, buf: &str) {
101        (**self).consume(buf);
102    }
103}
104
105/// Pushes the str onto the end of the String
106impl StrConsumer for String {
107    fn consume(&mut self, buf: &str) {
108        self.push_str(buf);
109    }
110}
111
112/// A `Write` that only can handle bytes that are valid single-byte UTF-8 code units.
113///
114/// This is safe because we only use it when writing base64, which is always valid UTF-8.
115struct Utf8SingleCodeUnitWriter<S: StrConsumer> {
116    str_consumer: S,
117}
118
119impl<S: StrConsumer> io::Write for Utf8SingleCodeUnitWriter<S> {
120    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
121        // Because we expect all input to be valid utf-8 individual bytes, we can encode any buffer
122        // length
123        let s = std::str::from_utf8(buf).expect("Input must be valid UTF-8");
124
125        self.str_consumer.consume(s);
126
127        Ok(buf.len())
128    }
129
130    fn flush(&mut self) -> io::Result<()> {
131        // no op
132        Ok(())
133    }
134}
135
#[cfg(test)]
mod tests {
    use crate::{
        engine::Engine, tests::random_engine, write::encoder_string_writer::EncoderStringWriter,
    };
    use rand::Rng;
    use std::cmp;
    use std::io::Write;

    /// Writing the input as two pieces, split at any position, must produce the same base64 as
    /// encoding the whole input in one call.
    #[test]
    fn every_possible_split_of_input() {
        let mut rng = rand::thread_rng();
        let mut orig_data = Vec::<u8>::new();
        let mut normal_encoded = String::new();

        let size = 5_000;

        // The range is inclusive so that the last split (all of the data, then an empty write) is
        // covered, mirroring the `i == 0` case (an empty write, then all of the data).
        for i in 0..=size {
            orig_data.clear();
            normal_encoded.clear();

            orig_data.resize(size, 0);
            rng.fill(&mut orig_data[..]);

            let engine = random_engine(&mut rng);
            engine.encode_string(&orig_data, &mut normal_encoded);

            let mut stream_encoder = EncoderStringWriter::new(&engine);
            // Write the first i bytes, then the rest
            stream_encoder.write_all(&orig_data[0..i]).unwrap();
            stream_encoder.write_all(&orig_data[i..]).unwrap();

            let stream_encoded = stream_encoder.into_inner();

            assert_eq!(normal_encoded, stream_encoded);
        }
    }

    /// Feeding the input through many small `write` calls of random length must produce the same
    /// base64 as encoding the whole input in one call.
    #[test]
    fn incremental_writes() {
        let mut rng = rand::thread_rng();
        let mut orig_data = Vec::<u8>::new();
        let mut normal_encoded = String::new();

        let size = 5_000;

        for _ in 0..size {
            orig_data.clear();
            normal_encoded.clear();

            orig_data.resize(size, 0);
            rng.fill(&mut orig_data[..]);

            let engine = random_engine(&mut rng);
            engine.encode_string(&orig_data, &mut normal_encoded);

            let mut stream_encoder = EncoderStringWriter::new(&engine);
            // write small nibbles of data
            let mut offset = 0;
            while offset < size {
                // Never request more than what is left of the input.
                let nibble_size = cmp::min(rng.gen_range(0..=64), size - offset);
                // `write` may accept fewer bytes than offered, so advance by what it reports.
                let len = stream_encoder
                    .write(&orig_data[offset..offset + nibble_size])
                    .unwrap();
                offset += len;
            }

            let stream_encoded = stream_encoder.into_inner();

            assert_eq!(normal_encoded, stream_encoded);
        }
    }
}