1use fidl::client::QueryResponseFut;
6use flex_fuchsia_io as fio;
7use futures::io::AsyncRead;
8use std::cmp::min;
9use std::convert::TryInto as _;
10use std::future::Future as _;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use zx_status_ext::StatusExt;
14
15use flex_client::fidl::Proxy as _;
16
/// An `AsyncRead` adapter over a `fio::FileProxy`.
///
/// Each read sends a `File.Read` request for at most `min(buf.len(), fio::MAX_BUF)` bytes.
/// Any part of a response that does not fit into the caller's buffer is retained and handed
/// out by subsequent reads. Construct with `AsyncReader::from_proxy`.
#[derive(Debug)]
pub struct AsyncReader {
    // Proxy rebuilt by `from_proxy` from a channel it was able to take exclusively.
    file: fio::FileProxy,
    // Read state machine driven by `poll_read`.
    state: State,
}
24
/// Internal read state of an `AsyncReader`.
#[derive(Debug)]
enum State {
    /// No `File.Read` request is outstanding and no response bytes are buffered.
    Empty,
    /// A `File.Read` request has been sent and its response is awaited.
    Forwarding {
        // Pending response. The future resolves to a transport result wrapping the server's
        // answer: either the bytes read or a raw status (later converted with
        // `zx_status::Status::from_raw`).
        fut: QueryResponseFut<Result<Vec<u8>, i32>, flex_client::Dialect>,
        // True when the request asked for zero bytes, which happens when the `poll_read`
        // that sent it was given an empty buffer.
        zero_byte_request: bool,
    },
    /// A response has arrived but has not been fully copied out to callers yet.
    Bytes {
        // The full response payload.
        bytes: Vec<u8>,
        // Index of the first byte in `bytes` not yet returned; `bytes[offset..]` remains.
        offset: usize,
    },
}
37
38impl AsyncReader {
39 pub fn from_proxy(file: fio::FileProxy) -> Result<Self, AsyncReaderError> {
51 let file = match file.into_channel() {
52 Ok(channel) => fio::FileProxy::new(channel),
53 Err(file) => {
54 return Err(AsyncReaderError::NonExclusiveChannelOwnership(file));
55 }
56 };
57 Ok(Self { file, state: State::Empty })
58 }
59}
60
61impl AsyncRead for AsyncReader {
62 fn poll_read(
63 mut self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &mut [u8],
66 ) -> Poll<std::io::Result<usize>> {
67 loop {
68 match self.state {
69 State::Empty => {
70 let len = if let Ok(len) = buf.len().try_into() {
71 min(len, fio::MAX_BUF)
72 } else {
73 fio::MAX_BUF
74 };
75 self.state =
76 State::Forwarding { fut: self.file.read(len), zero_byte_request: len == 0 };
77 }
78 State::Forwarding { ref mut fut, ref zero_byte_request } => {
79 match futures::ready!(Pin::new(fut).poll(cx)) {
80 Ok(result) => {
81 match result {
82 Err(s) => {
83 self.state = State::Empty;
84 return Poll::Ready(Err(
85 zx_status::Status::from_raw(s).into_io_error()
86 ));
87 }
88 Ok(bytes) => {
89 if *zero_byte_request && buf.len() != 0 {
98 self.state = State::Empty;
99 } else {
100 self.state = State::Bytes { bytes, offset: 0 };
101 }
102 }
103 }
104 }
105 Err(e) => {
106 self.state = State::Empty;
107 return Poll::Ready(Err(std::io::Error::other(e)));
108 }
109 }
110 }
111 State::Bytes { ref bytes, ref mut offset } => {
112 let n = min(buf.len(), bytes.len() - *offset);
113 let next_offset = *offset + n;
114 let () = buf[..n].copy_from_slice(&bytes[*offset..next_offset]);
115 if next_offset == bytes.len() {
116 self.state = State::Empty;
117 } else {
118 *offset = next_offset;
119 }
120 return Poll::Ready(Ok(n));
121 }
122 }
123 }
124 }
125}
126
/// Error returned by `AsyncReader::from_proxy`.
#[derive(Debug, thiserror::Error)]
pub enum AsyncReaderError {
    /// The proxy's channel could not be taken exclusively. This happens, for example, when
    /// the proxy's event stream has already been taken. The proxy is handed back so the
    /// caller still owns it.
    #[error("Supplied FileProxy did not have exclusive ownership of the underlying channel")]
    NonExclusiveChannelOwnership(fio::FileProxy),
}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file;
    use assert_matches::assert_matches;
    use fidl::endpoints;
    use fuchsia_async as fasync;
    use futures::future::poll_fn;
    use futures::io::AsyncReadExt as _;
    use futures::{StreamExt as _, TryStreamExt as _, join};
    use std::convert::TryFrom as _;
    use tempfile::TempDir;

    // `from_proxy` rejects a proxy whose channel it cannot take exclusively. Taking the
    // event stream is enough to trigger that.
    #[fasync::run_singlethreaded(test)]
    async fn exclusive_ownership() {
        let (proxy, _) = endpoints::create_proxy::<fio::FileMarker>();
        let _stream = proxy.take_event_stream();

        assert_matches!(AsyncReader::from_proxy(proxy), Err(_));
    }

    /// Writes `expected_contents` to a temporary file, reads it back to EOF through an
    /// `AsyncReader`, and asserts that the bytes round-trip unchanged.
    async fn read_to_end_file_with_expected_contents(expected_contents: &[u8]) {
        let dir = TempDir::new().unwrap();
        let path =
            dir.path().join("read_to_end_with_expected_contents").to_str().unwrap().to_owned();
        let () = file::write_in_namespace(&path, expected_contents).await.unwrap();
        let file = file::open_in_namespace(&path, fio::PERM_READABLE).unwrap();

        let mut reader = AsyncReader::from_proxy(file).unwrap();
        let mut actual_contents = vec![];
        reader.read_to_end(&mut actual_contents).await.unwrap();

        assert_eq!(actual_contents, expected_contents);
    }

    #[fasync::run_singlethreaded(test)]
    async fn read_to_end_empty() {
        read_to_end_file_with_expected_contents(&[]).await;
    }

    // Three times the per-request cap, so reading it requires several `File.Read` calls.
    #[fasync::run_singlethreaded(test)]
    async fn read_to_end_large() {
        let expected_contents = vec![7u8; (fio::MAX_BUF * 3).try_into().unwrap()];
        read_to_end_file_with_expected_contents(&expected_contents[..]).await;
    }

    /// Polls a fresh reader once with a `poll_read_size`-byte buffer, expecting `Pending`.
    /// Then asserts that the server received a `File.Read` asking for
    /// `expected_file_read_size` bytes.
    async fn poll_read_with_specific_buf_size(poll_read_size: u64, expected_file_read_size: u64) {
        let (proxy, mut stream) = endpoints::create_proxy_and_stream::<fio::FileMarker>();

        let mut reader = AsyncReader::from_proxy(proxy).unwrap();

        let () = poll_fn(|cx| {
            let mut buf = vec![0u8; poll_read_size.try_into().unwrap()];
            assert_matches!(Pin::new(&mut reader).poll_read(cx, buf.as_mut_slice()), Poll::Pending);
            Poll::Ready(())
        })
        .await;

        match stream.next().await.unwrap().unwrap() {
            fio::FileRequest::Read { count, .. } => {
                assert_eq!(count, expected_file_read_size);
            }
            req => panic!("unhandled request {:?}", req),
        }
    }

    // An empty buffer still sends a (zero-byte) request.
    #[fasync::run_singlethreaded(test)]
    async fn poll_read_empty_buf() {
        poll_read_with_specific_buf_size(0, 0).await;
    }

    // Requests are capped at `fio::MAX_BUF` regardless of the caller's buffer size.
    #[fasync::run_singlethreaded(test)]
    async fn poll_read_caps_buf_size() {
        poll_read_with_specific_buf_size(fio::MAX_BUF * 2, fio::MAX_BUF).await;
    }

    // A `poll_read` that returns `Pending` must keep its in-flight request. A later read
    // then completes from that request instead of sending a second `File.Read`.
    #[fasync::run_singlethreaded(test)]
    async fn poll_read_pending_saves_future() {
        let (proxy, mut stream) = endpoints::create_proxy_and_stream::<fio::FileMarker>();

        let mut reader = AsyncReader::from_proxy(proxy).unwrap();

        let () = poll_fn(|cx| {
            // The first poll sends a 1-byte request; no response has been sent yet.
            assert_matches!(Pin::new(&mut reader).poll_read(cx, &mut [0u8; 1]), Poll::Pending);
            Poll::Ready(())
        })
        .await;

        let poll_read = async move {
            // Completes from the saved request. `reader` is moved into this future and is
            // dropped when it finishes. That presumably closes the channel and lets the
            // server loop below terminate.
            let mut buf = [0u8; 1];
            assert_eq!(reader.read(&mut buf).await.unwrap(), buf.len());
            assert_eq!(&buf, &[1]);
        };

        // Each response carries the 1-based request number, so the byte read above
        // identifies which request it answered.
        let mut file_read_requests = 0u8;
        let handle_file_stream = async {
            while let Some(req) = stream.try_next().await.unwrap() {
                file_read_requests += 1;
                match req {
                    fio::FileRequest::Read { count, responder } => {
                        assert_eq!(count, 1);
                        responder.send(Ok(&[file_read_requests])).unwrap();
                    }
                    req => panic!("unhandled request {:?}", req),
                }
            }
        };

        let ((), ()) = join!(poll_read, handle_file_stream);
        assert_eq!(file_read_requests, 1);
    }

    // A pending request sized for a 3-byte buffer is drained across later, smaller reads
    // before any new request is sent.
    #[fasync::run_singlethreaded(test)]
    async fn poll_read_with_smaller_buf_after_pending() {
        let (proxy, mut stream) = endpoints::create_proxy_and_stream::<fio::FileMarker>();

        let mut reader = AsyncReader::from_proxy(proxy).unwrap();

        let () = poll_fn(|cx| {
            // Sends a 3-byte request that stays pending.
            assert_matches!(Pin::new(&mut reader).poll_read(cx, &mut [0u8; 3]), Poll::Pending);
            Poll::Ready(())
        })
        .await;

        let () = async {
            // Answer the outstanding 3-byte request.
            match stream.next().await.unwrap().unwrap() {
                fio::FileRequest::Read { count, responder } => {
                    assert_eq!(count, 3);
                    responder.send(Ok(b"012")).unwrap();
                }
                req => panic!("unhandled request {:?}", req),
            }
        }
        .await;

        // The buffered response is handed out one byte at a time...
        let mut buf = [0u8; 1];
        assert_eq!(reader.read(&mut buf).await.unwrap(), buf.len());
        assert_eq!(&buf, b"0");

        let mut buf = [0u8; 1];
        assert_eq!(reader.read(&mut buf).await.unwrap(), buf.len());
        assert_eq!(&buf, b"1");

        // ...and a larger buffer only receives what is left, without a new request.
        let mut buf = [0u8; 2];
        assert_eq!(reader.read(&mut buf).await.unwrap(), 1);
        assert_eq!(&buf[..1], b"2");

        // Once drained, the next read sends a new request sized to its buffer.
        let mut buf = [0u8; 4];
        let poll_read = reader.read(&mut buf);

        let handle_second_file_request = async {
            match stream.next().await.unwrap().unwrap() {
                fio::FileRequest::Read { count, responder } => {
                    assert_eq!(count, 4);
                    responder.send(Ok(b"3456")).unwrap();
                }
                req => panic!("unhandled request {:?}", req),
            }
        };

        let (read_res, ()) = join!(poll_read, handle_second_file_request);
        assert_eq!(read_res.unwrap(), 4);
        assert_eq!(&buf, b"3456");
    }

    // A transport-level (FIDL) error resets the reader to `Empty`.
    #[fasync::run_singlethreaded(test)]
    async fn transition_to_empty_on_fidl_error() {
        let (proxy, _) = endpoints::create_proxy_and_stream::<fio::FileMarker>();

        let mut reader = AsyncReader::from_proxy(proxy).unwrap();

        let () = poll_fn(|cx| {
            // The server end was dropped by the `_` binding above, so the read fails
            // immediately.
            assert_matches!(
                Pin::new(&mut reader).poll_read(cx, &mut [0u8; 1]),
                Poll::Ready(Err(_))
            );
            Poll::Ready(())
        })
        .await;

        assert_matches!(reader.state, State::Empty);
    }

    // A server-reported error status is surfaced and resets the reader, so the following
    // read sends a fresh request and succeeds.
    #[fasync::run_singlethreaded(test)]
    async fn recover_from_file_read_error() {
        let (proxy, mut stream) = endpoints::create_proxy_and_stream::<fio::FileMarker>();

        let mut reader = AsyncReader::from_proxy(proxy).unwrap();

        let mut buf = [0u8; 1];
        let poll_read = reader.read(&mut buf);

        let failing_file_response = async {
            match stream.next().await.unwrap().unwrap() {
                fio::FileRequest::Read { count, responder } => {
                    assert_eq!(count, 1);
                    responder.send(Err(zx_status::Status::NO_MEMORY.into_raw())).unwrap();
                }
                req => panic!("unhandled request {:?}", req),
            }
        };

        let (read_res, ()) = join!(poll_read, failing_file_response);
        assert_matches!(read_res, Err(_));

        // The next read sends a new request rather than reusing the failed one.
        let mut buf = [0u8; 1];
        let poll_read = reader.read(&mut buf);

        let succeeding_file_response = async {
            match stream.next().await.unwrap().unwrap() {
                fio::FileRequest::Read { count, responder } => {
                    assert_eq!(count, 1);
                    responder.send(Ok(b"0")).unwrap();
                }
                req => panic!("unhandled request {:?}", req),
            }
        };

        let (read_res, ()) = join!(poll_read, succeeding_file_response);
        assert_eq!(read_res.unwrap(), 1);
        assert_eq!(&buf, b"0");
    }

    // A zero-byte request left over from an empty-buffer poll must not be reported as EOF
    // to a later non-empty read. A new, correctly sized request is sent instead.
    #[fasync::run_singlethreaded(test)]
    async fn poll_read_zero_then_read_nonzero() {
        let (proxy, mut stream) = endpoints::create_proxy_and_stream::<fio::FileMarker>();

        let mut reader = AsyncReader::from_proxy(proxy).unwrap();

        let () = poll_fn(|cx| {
            // An empty buffer sends a zero-byte request that stays pending.
            assert_matches!(Pin::new(&mut reader).poll_read(cx, &mut []), Poll::Pending);
            Poll::Ready(())
        })
        .await;

        match stream.next().await.unwrap().unwrap() {
            // Answer the zero-byte request with an empty response.
            fio::FileRequest::Read { count, responder } => {
                assert_eq!(count, 0);
                responder.send(Ok(&[])).unwrap();
            }
            req => panic!("unhandled request {:?}", req),
        }

        let mut buf = vec![0u8; 1];
        let poll_read = reader.read(&mut buf);

        let handle_file_request = async {
            // The empty response is discarded rather than returned as EOF, and a 1-byte
            // request follows.
            match stream.next().await.unwrap().unwrap() {
                fio::FileRequest::Read { count, responder } => {
                    assert_eq!(count, 1);
                    responder.send(Ok(&[1])).unwrap();
                }
                req => panic!("unhandled request {:?}", req),
            }
        };

        let (poll_read, ()) = join!(poll_read, handle_file_request);

        assert_eq!(poll_read.unwrap(), 1);
        assert_eq!(&buf[..], &[1]);
    }

    // Exhaustively combines first-poll buffer sizes, file sizes, and second-read buffer
    // sizes.
    #[fasync::run_singlethreaded(test)]
    async fn different_poll_read_and_file_sizes() {
        for first_poll_read_len in 0..5 {
            for file_size in 0..5 {
                for second_poll_read_len in 0..5 {
                    let (proxy, mut stream) =
                        endpoints::create_proxy_and_stream::<fio::FileMarker>();

                    let mut reader = AsyncReader::from_proxy(proxy).unwrap();

                    let () = poll_fn(|cx| {
                        // The first poll sends a request sized to its buffer and stays pending.
                        let mut buf = vec![0u8; first_poll_read_len];
                        assert_matches!(
                            Pin::new(&mut reader).poll_read(cx, &mut buf),
                            Poll::Pending
                        );
                        Poll::Ready(())
                    })
                    .await;

                    match stream.next().await.unwrap().unwrap() {
                        // Respond with no more bytes than were requested or than the
                        // simulated file holds.
                        fio::FileRequest::Read { count, responder } => {
                            assert_eq!(count, u64::try_from(first_poll_read_len).unwrap());
                            let resp = vec![7u8; min(file_size, first_poll_read_len)];
                            responder.send(Ok(&resp)).unwrap();
                        }
                        req => panic!("unhandled request {:?}", req),
                    }

                    // The second read is normally served from the first response. The
                    // exception is when the first request was zero bytes and this buffer is
                    // non-empty; then a new request is expected.
                    let mut buf = vec![0u8; second_poll_read_len];
                    let poll_read = reader.read(&mut buf);

                    let handle_conditional_file_request = async {
                        if first_poll_read_len == 0 && second_poll_read_len != 0 {
                            match stream.next().await.unwrap().unwrap() {
                                fio::FileRequest::Read { count, responder } => {
                                    assert_eq!(count, u64::try_from(second_poll_read_len).unwrap());
                                    let resp = vec![7u8; min(file_size, second_poll_read_len)];
                                    responder.send(Ok(&resp)).unwrap();
                                }
                                req => panic!("unhandled request {:?}", req),
                            }
                        }
                    };

                    let (read_res, ()) = join!(poll_read, handle_conditional_file_request);

                    let expected_len = if first_poll_read_len == 0 {
                        min(file_size, second_poll_read_len)
                    } else {
                        min(first_poll_read_len, min(file_size, second_poll_read_len))
                    };
                    let expected = vec![7u8; expected_len];
                    assert_eq!(read_res.unwrap(), expected_len);
                    assert_eq!(&buf[..expected_len], &expected[..]);
                }
            }
        }
    }
}