1use fidl_fuchsia_fxfs::BlobWriterProxy;
6
7use futures::future::{BoxFuture, FutureExt as _};
8use futures::stream::{FuturesOrdered, StreamExt as _, TryStreamExt as _};
9
10mod errors;
11pub use errors::{CreateError, WriteError};
12
/// Client-side writer for a single blob over the `fuchsia.fxfs/BlobWriter` protocol.
///
/// Blob bytes are copied into a VMO obtained from the server with `GetVmo`, which is used as a
/// ring buffer. Each copied chunk is announced to the server with a `BytesReady` request, and the
/// space it occupies is only reused once that request has been acknowledged.
#[derive(Debug)]
pub struct BlobWriter {
    // Connection to the server-side blob writer.
    blob_writer_proxy: BlobWriterProxy,
    // VMO returned by `GetVmo`. Writes are placed at `bytes_sent % vmo_len` and wrap around to
    // offset 0 when they run past the end.
    vmo: zx::Vmo,
    // In-flight `BytesReady` requests, resolved in the order they were sent. Each one yields the
    // number of bytes it covered once the server acknowledges it. `write` keeps at most two of
    // these outstanding at a time.
    outstanding_writes:
        FuturesOrdered<BoxFuture<'static, Result<Result<u64, zx::Status>, fidl::Error>>>,
    // Total bytes copied into the VMO and announced with `BytesReady`, acknowledged or not.
    bytes_sent: u64,
    // Bytes of the VMO not currently occupied by unacknowledged writes.
    available: u64,
    // Size of the blob being written (the `size` passed to `create`).
    blob_len: u64,
    // Size of the VMO as reported by `get_size` at creation time.
    vmo_len: u64,
}
43
44impl BlobWriter {
45 pub async fn create(
47 blob_writer_proxy: BlobWriterProxy,
48 size: u64,
49 ) -> Result<Self, CreateError> {
50 let vmo = blob_writer_proxy
51 .get_vmo(size)
52 .await
53 .map_err(CreateError::Fidl)?
54 .map_err(zx::Status::from_raw)
55 .map_err(CreateError::GetVmo)?;
56 let vmo_len = vmo.get_size().map_err(CreateError::GetSize)?;
57 Ok(BlobWriter {
58 blob_writer_proxy,
59 vmo,
60 outstanding_writes: FuturesOrdered::new(),
61 bytes_sent: 0,
62 available: vmo_len,
63 blob_len: size,
64 vmo_len,
65 })
66 }
67
68 pub async fn write(&mut self, mut bytes: &[u8]) -> Result<(), WriteError> {
80 if self.bytes_sent + bytes.len() as u64 > self.blob_len {
81 return Err(WriteError::EndOfBlob);
82 }
83 while !bytes.is_empty() {
84 debug_assert!(self.outstanding_writes.len() <= 2);
85 if self.available == 0 || self.outstanding_writes.len() == 2 {
87 let bytes_ackd = self
88 .outstanding_writes
89 .next()
90 .await
91 .ok_or_else(|| WriteError::QueueEnded)?
92 .map_err(WriteError::Fidl)?
93 .map_err(WriteError::BytesReady)?;
94 self.available += bytes_ackd;
95 }
96
97 let bytes_to_send_len = {
98 let mut bytes_to_send_len = std::cmp::min(self.available, bytes.len() as u64);
99 if self.blob_len - self.bytes_sent > self.vmo_len {
102 bytes_to_send_len = std::cmp::min(bytes_to_send_len, self.vmo_len / 2)
103 }
104 bytes_to_send_len
105 };
106
107 let (bytes_to_send, remaining_bytes) = bytes.split_at(bytes_to_send_len as usize);
108 bytes = remaining_bytes;
109
110 let vmo_index = self.bytes_sent % self.vmo_len;
111 let (bytes_to_send_before_wrap, bytes_to_send_after_wrap) = bytes_to_send
112 .split_at(std::cmp::min((self.vmo_len - vmo_index) as usize, bytes_to_send.len()));
113
114 self.vmo.write(bytes_to_send_before_wrap, vmo_index).map_err(WriteError::VmoWrite)?;
115 if !bytes_to_send_after_wrap.is_empty() {
116 self.vmo.write(bytes_to_send_after_wrap, 0).map_err(WriteError::VmoWrite)?;
117 }
118
119 let write_fut = self.blob_writer_proxy.bytes_ready(bytes_to_send_len);
120 self.outstanding_writes.push_back(
121 async move {
122 write_fut
123 .await
124 .map(|res| res.map(|()| bytes_to_send_len).map_err(zx::Status::from_raw))
125 }
126 .boxed(),
127 );
128 self.available -= bytes_to_send_len;
129 self.bytes_sent += bytes_to_send_len;
130 }
131 debug_assert!(self.bytes_sent <= self.blob_len);
132
133 if self.bytes_sent == self.blob_len {
135 while let Some(result) =
136 self.outstanding_writes.try_next().await.map_err(WriteError::Fidl)?
137 {
138 match result {
139 Ok(bytes_ackd) => self.available += bytes_ackd,
140 Err(e) => return Err(WriteError::BytesReady(e)),
141 }
142 }
143 if self.available != self.vmo_len {
145 return Err(WriteError::EndOfBlob);
146 }
147 }
148 Ok(())
149 }
150
151 pub fn vmo_size(&self) -> u64 {
152 self.vmo_len
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use fidl::endpoints::create_proxy_and_stream;
    use fidl_fuchsia_fxfs::{BlobWriterMarker, BlobWriterRequest};
    use futures::{pin_mut, select};
    use rand::{thread_rng, Rng as _};
    use std::sync::{Arc, Mutex};
    use zx::HandleBased;

    /// Size of the VMO the mock server hands out.
    const VMO_SIZE: usize = 4096;

    /// Returns an `N`-byte array filled with random bytes.
    fn random_data<const N: usize>() -> [u8; N] {
        let mut data = [0u8; N];
        thread_rng().fill(&mut data[..]);
        data
    }

    /// Reads `len` bytes from `vmo` starting at `offset`. If the read runs past `VMO_SIZE`, it
    /// continues from offset 0.
    fn read_ring(vmo: &zx::Vmo, offset: usize, len: usize) -> Vec<u8> {
        let mut buf = vec![0u8; len];
        if offset + len > VMO_SIZE {
            let (head, tail) = buf.split_at_mut(VMO_SIZE - offset);
            vmo.read(head, offset as u64).unwrap();
            vmo.read(tail, 0).unwrap();
        } else {
            vmo.read(&mut buf, offset as u64).unwrap();
        }
        buf
    }

    /// Drives `write_fun` against a mock `BlobWriter` server and checks each request it sends.
    ///
    /// * `GetVmo` gets a fresh `VMO_SIZE`-byte VMO.
    /// * The `i`-th `BytesReady` must carry `writes[i].1 - writes[i].0` bytes. The VMO contents at
    ///   `writes[i].0 % VMO_SIZE`, wrapping if needed, must equal `data[writes[i].0..writes[i].1]`.
    /// * When `write_fun` finishes, exactly `writes.len()` `BytesReady` requests must have been
    ///   seen.
    async fn check_blob_writer(
        write_fun: impl FnOnce(BlobWriterProxy) -> BoxFuture<'static, ()>,
        data: &[u8],
        writes: &[(usize, usize)],
    ) {
        let (proxy, mut stream) = create_proxy_and_stream::<BlobWriterMarker>();
        let requests_seen = Arc::new(Mutex::new(0));
        let server_requests_seen = Arc::clone(&requests_seen);
        let expected_requests = writes.len();
        let mut server_vmo = None;
        let mock_server = async move {
            while let Some(request) = stream.next().await {
                match request {
                    Ok(BlobWriterRequest::GetVmo { responder, .. }) => {
                        let vmo = zx::Vmo::create(VMO_SIZE as u64).expect("failed to create vmo");
                        let client_vmo = vmo
                            .duplicate_handle(zx::Rights::SAME_RIGHTS)
                            .expect("failed to duplicate VMO");
                        server_vmo = Some(vmo);
                        responder.send(Ok(client_vmo)).unwrap();
                    }
                    Ok(BlobWriterRequest::BytesReady { responder, bytes_written, .. }) => {
                        let vmo = server_vmo.as_ref().unwrap();
                        let mut seen = server_requests_seen.lock().unwrap();
                        let (start, end) = writes[*seen];
                        let contents = read_ring(vmo, start % VMO_SIZE, bytes_written as usize);
                        assert_eq!(bytes_written, (end - start) as u64);
                        assert_eq!(&data[start..end], contents);
                        *seen += 1;
                        responder.send(Ok(())).unwrap();
                    }
                    _ => unreachable!(),
                }
            }
        }
        .fuse();

        pin_mut!(mock_server);

        select! {
            _ = mock_server => unreachable!(),
            _ = write_fun(proxy).fuse() => {
                assert_eq!(*requests_seen.lock().unwrap(), expected_requests);
            }
        }
    }

    #[fuchsia::test]
    async fn invalid_write_past_end_of_blob() {
        let data: [u8; VMO_SIZE] = random_data();

        let write_fun = |proxy: BlobWriterProxy| {
            async move {
                let mut writer = BlobWriter::create(proxy, data.len() as u64)
                    .await
                    .expect("failed to create BlobWriter");
                writer.write(&data).await.unwrap();
                // The blob is already complete, so any further bytes lie past its end.
                let past_end = [0; 4096];
                assert_matches!(writer.write(&past_end).await, Err(WriteError::EndOfBlob));
            }
            .boxed()
        };

        check_blob_writer(write_fun, &data, &[(0, VMO_SIZE)]).await;
    }

    #[fuchsia::test]
    async fn do_not_split_writes_if_blob_fits_in_vmo() {
        let data: [u8; VMO_SIZE - 1] = random_data();

        let write_fun = |proxy: BlobWriterProxy| {
            async move {
                let mut writer = BlobWriter::create(proxy, data.len() as u64)
                    .await
                    .expect("failed to create BlobWriter");
                writer.write(&data).await.unwrap();
            }
            .boxed()
        };

        check_blob_writer(write_fun, &data, &[(0, VMO_SIZE - 1)]).await;
    }

    #[fuchsia::test]
    async fn split_writes_if_blob_does_not_fit_in_vmo() {
        let data: [u8; VMO_SIZE + 1] = random_data();

        let write_fun = |proxy: BlobWriterProxy| {
            async move {
                let mut writer = BlobWriter::create(proxy, data.len() as u64)
                    .await
                    .expect("failed to create BlobWriter");
                writer.write(&data).await.unwrap();
            }
            .boxed()
        };

        check_blob_writer(write_fun, &data, &[(0, 2048), (2048, 4096), (4096, 4097)]).await;
    }

    #[fuchsia::test]
    async fn third_write_wraps() {
        let data: [u8; 1024 * 6] = random_data();

        let writes =
            [(0, 1024 * 2), (1024 * 2, 1024 * 3), (1024 * 3, 1024 * 5), (1024 * 5, 1024 * 6)];

        let write_fun = |proxy: BlobWriterProxy| {
            async move {
                let mut writer = BlobWriter::create(proxy, data.len() as u64)
                    .await
                    .expect("failed to create BlobWriter");
                for (start, end) in writes {
                    writer.write(&data[start..end]).await.unwrap();
                }
            }
            .boxed()
        };

        check_blob_writer(write_fun, &data, &writes[..]).await;
    }

    #[fuchsia::test]
    async fn many_wraps() {
        let data: [u8; VMO_SIZE * 3] = random_data();

        let write_fun = |proxy: BlobWriterProxy| {
            async move {
                let mut writer = BlobWriter::create(proxy, data.len() as u64)
                    .await
                    .expect("failed to create BlobWriter");
                // A 1-byte write first shifts every later chunk off the half-VMO boundary.
                writer.write(&data[..1]).await.unwrap();
                writer.write(&data[1..]).await.unwrap();
            }
            .boxed()
        };

        let expected_writes = [
            (0, 1),
            (1, 2049),
            (2049, 4097),
            (4097, 6145),
            (6145, 8193),
            (8193, 10241),
            (10241, 12288),
        ];
        check_blob_writer(write_fun, &data, &expected_writes).await;
    }
}