payload_streamer/
lib.rs

// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use anyhow::{Context as _, Error};
6use async_trait::async_trait;
7use block_client::{BlockClient, MutableBufferSlice, RemoteBlockClient, VmoId};
8use fidl_fuchsia_hardware_block::BlockProxy;
9use fidl_fuchsia_paver::{PayloadStreamRequest, PayloadStreamRequestStream, ReadInfo, ReadResult};
10
11use futures::lock::Mutex;
12use futures::prelude::*;
13use mapped_vmo::Mapping;
14use std::io::Read;
15
/// Progress callback, invoked as `callback(data_read, data_total)`.
///
/// Any `Send + Sync` closure or function taking two `usize` arguments implements this trait
/// through the blanket impl below.
pub trait StatusCallback: Fn(usize, usize) + Send + Sync {}

impl<F: Fn(usize, usize) + Send + Sync> StatusCallback for F {}
19
20#[async_trait]
21pub trait PayloadStreamer {
22    /// Handle the server side of the PayloadStream service.
23    async fn service_payload_stream_requests(
24        self: Box<Self>,
25        stream: PayloadStreamRequestStream,
26        status_callback: Option<&dyn StatusCallback>,
27    ) -> Result<(), Error>;
28}
29
30struct ReaderPayloadStreamerInner {
31    src: Box<dyn Read + Sync + Send>,
32    src_read: usize,           // Read offset into the reader.
33    src_size: usize,           // Size of the reader.
34    dest_buf: Option<Mapping>, // Maps the VMO used for the PayloadStream protocol.
35    dest_size: usize,          // Size of the VMO used for the PayloadStream protocol.
36}
37
38/// Streams the contents of a reader over the PayloadStream protocol.
39pub struct ReaderPayloadStreamer {
40    // We wrap all our state inside a mutex, to make it mutable.
41    inner: Mutex<ReaderPayloadStreamerInner>,
42}
43
44impl ReaderPayloadStreamer {
45    pub fn new(src: Box<dyn Read + Sync + Send>, src_size: usize) -> Self {
46        ReaderPayloadStreamer {
47            inner: Mutex::new(ReaderPayloadStreamerInner {
48                src,
49                src_read: 0,
50                src_size,
51                dest_buf: None,
52                dest_size: 0,
53            }),
54        }
55    }
56
57    /// Handle a single request from a FIDL client.
58    async fn handle_request(
59        self: &Self,
60        req: PayloadStreamRequest,
61        status_callback: Option<&dyn StatusCallback>,
62    ) -> Result<(), Error> {
63        let mut unwrapped = self.inner.lock().await;
64        match req {
65            PayloadStreamRequest::RegisterVmo { vmo, responder } => {
66                // Make sure we only get bound once.
67                if unwrapped.dest_buf.is_some() {
68                    responder.send(zx::sys::ZX_ERR_ALREADY_BOUND)?;
69                    return Ok(());
70                }
71
72                // Figure out information about the new VMO.
73                let size = vmo.get_size();
74                if let Err(e) = size {
75                    responder.send(e.into_raw())?;
76                    return Ok(());
77                }
78
79                let size = size.unwrap() as usize;
80
81                let mapping = Mapping::create_from_vmo(
82                    &vmo,
83                    size,
84                    zx::VmarFlags::PERM_READ | zx::VmarFlags::PERM_WRITE,
85                );
86
87                if let Err(e) = mapping {
88                    responder.send(e.into_raw())?;
89                    return Ok(());
90                }
91
92                unwrapped.dest_buf = Some(mapping.unwrap());
93                unwrapped.dest_size = size;
94                responder.send(zx::sys::ZX_OK)?;
95            }
96            PayloadStreamRequest::ReadData { responder } => {
97                if unwrapped.dest_buf.is_none() || unwrapped.dest_size == 0 {
98                    responder.send(&ReadResult::Err { 0: zx::sys::ZX_ERR_BAD_STATE })?;
99                    return Ok(());
100                }
101
102                let data_left = unwrapped.src_size - unwrapped.src_read;
103                let data_to_read = std::cmp::min(data_left, unwrapped.dest_size);
104                let mut buf: Vec<u8> = vec![0; data_to_read];
105                let read = unwrapped.src.read(&mut buf);
106                if let Err(e) = read {
107                    responder.send(&ReadResult::Err {
108                        0: e.raw_os_error().unwrap_or(zx::sys::ZX_ERR_INTERNAL),
109                    })?;
110                    return Ok(());
111                }
112                let read = read?;
113                if read == 0 {
114                    responder.send(&ReadResult::Eof { 0: true })?;
115                    return Ok(());
116                }
117
118                unwrapped.dest_buf.as_ref().unwrap().write(&buf);
119
120                unwrapped.src_read += read;
121                responder.send(&ReadResult::Info {
122                    0: ReadInfo { offset: 0, size: data_to_read as u64 },
123                })?;
124
125                let src_read = unwrapped.src_read;
126                let src_size = unwrapped.src_size;
127                if let Some(cb) = status_callback {
128                    cb(src_read, src_size);
129                }
130            }
131        }
132        return Ok(());
133    }
134}
135
136#[async_trait]
137impl PayloadStreamer for ReaderPayloadStreamer {
138    async fn service_payload_stream_requests(
139        self: Box<Self>,
140        stream: PayloadStreamRequestStream,
141        status_callback: Option<&dyn StatusCallback>,
142    ) -> Result<(), Error> {
143        stream
144            .map(|result| result.context("failed request"))
145            .try_for_each(|request| async { self.handle_request(request, status_callback).await })
146            .await
147    }
148}
149
150struct BlockDevicePayloadStreamerInner {
151    device: RemoteBlockClient,
152    device_read: usize,        // Read offset into the block device.
153    device_size: usize,        // Size of the block device.
154    device_vmo_id: VmoId,      // VMO id used to read from the RemoteBlockClient.
155    device_buf: Mapping, // Maps the VMO the RemoteBlockClient uses to read from the block device.
156    device_vmo_read: usize, // Read offset into the VMO used to read from the block device.
157    dest_buf: Option<Mapping>, // Maps the VMO used for the PayloadStream protocol.
158    dest_size: usize,    // Size of the VMO used for the PayloadStream protocol.
159}
160
161/// Streams the contents of a block device over the PayloadStream protocol.
162pub struct BlockDevicePayloadStreamer {
163    // We wrap all our state inside a mutex, to make it mutable.
164    inner: Mutex<BlockDevicePayloadStreamerInner>,
165}
166
/// Size, in bytes, of the VMO used to read from the block device.
// TODO(https://fxbug.dev/42059224): Increasing this may speed up the transfer once the UMS crash
// is fixed.
const DEVICE_VMO_SIZE: usize = 16 * 8192;
169
170impl BlockDevicePayloadStreamer {
171    pub async fn new(block_device: BlockProxy) -> Result<Self, Error> {
172        let client = RemoteBlockClient::new(block_device).await?;
173
174        let device_vmo = zx::Vmo::create(DEVICE_VMO_SIZE as u64)?;
175        let device_vmo_id = client.attach_vmo(&device_vmo).await?;
176        let device_buf = Mapping::create_from_vmo(
177            &device_vmo,
178            DEVICE_VMO_SIZE,
179            zx::VmarFlags::PERM_READ | zx::VmarFlags::PERM_WRITE,
180        )?;
181
182        Ok(BlockDevicePayloadStreamer {
183            inner: Mutex::new(BlockDevicePayloadStreamerInner {
184                device_size: client.block_size() as usize * client.block_count() as usize,
185                device: client,
186                device_read: 0,
187                device_vmo_id,
188                device_buf,
189                device_vmo_read: 0,
190                dest_buf: None,
191                dest_size: 0,
192            }),
193        })
194    }
195
196    /// Handle a single request from a FIDL client.
197    async fn handle_request(
198        &self,
199        req: PayloadStreamRequest,
200        status_callback: Option<&dyn StatusCallback>,
201    ) -> Result<(), Error> {
202        let mut unwrapped = self.inner.lock().await;
203        match req {
204            PayloadStreamRequest::RegisterVmo { vmo, responder } => {
205                // Make sure we only get bound once.
206                if unwrapped.dest_buf.is_some() {
207                    responder.send(zx::sys::ZX_ERR_ALREADY_BOUND)?;
208                    return Ok(());
209                }
210
211                // Figure out information about the new VMO.
212                let size = vmo.get_size();
213                if let Err(e) = size {
214                    responder.send(e.into_raw())?;
215                    return Ok(());
216                }
217
218                let size = size.unwrap() as usize;
219                // Simplified logic if the size of the VMO used to read from the device is some
220                // multiple of the size of the VMO used for the PayloadStream protocol.
221                assert_eq!(DEVICE_VMO_SIZE % size, 0);
222
223                let mapping = Mapping::create_from_vmo(
224                    &vmo,
225                    size,
226                    zx::VmarFlags::PERM_READ | zx::VmarFlags::PERM_WRITE,
227                );
228
229                if let Err(e) = mapping {
230                    responder.send(e.into_raw())?;
231                    return Ok(());
232                }
233
234                unwrapped.dest_buf = Some(mapping.unwrap());
235                unwrapped.dest_size = size;
236
237                responder.send(zx::sys::ZX_OK)?;
238            }
239            PayloadStreamRequest::ReadData { responder } => {
240                if unwrapped.dest_buf.is_none() || unwrapped.dest_size == 0 {
241                    responder.send(&ReadResult::Err { 0: zx::sys::ZX_ERR_BAD_STATE })?;
242                    return Ok(());
243                }
244
245                let data_left = unwrapped.device_size - unwrapped.device_read;
246
247                if data_left == 0 {
248                    responder.send(&ReadResult::Eof { 0: true })?;
249                    return Ok(());
250                }
251
252                // Check if we need to read more data from the block device.
253                // We read more than `dest_size` bytes from the block device at a time for better
254                // throughput.
255                if unwrapped.device_read == 0 || unwrapped.device_vmo_read == DEVICE_VMO_SIZE {
256                    let data_to_read = std::cmp::min(data_left, DEVICE_VMO_SIZE);
257                    let buffer_slice = MutableBufferSlice::new_with_vmo_id(
258                        &unwrapped.device_vmo_id,
259                        0,
260                        data_to_read as u64,
261                    );
262
263                    if let Err(e) =
264                        unwrapped.device.read_at(buffer_slice, unwrapped.device_read as u64).await
265                    {
266                        responder.send(&ReadResult::Err { 0: e.into_raw() })?;
267                        return Ok(());
268                    }
269                    unwrapped.device_vmo_read = 0;
270                }
271
272                let data_to_return = std::cmp::min(data_left, unwrapped.dest_size);
273
274                // Copy data from the device VMO to the PayloadStream VMO.
275                // Avoiding the double copy here doesn't speed up the stream significantly.
276                let mut buf: Vec<u8> = vec![0; data_to_return];
277                unwrapped.device_buf.read_at(unwrapped.device_vmo_read, &mut buf);
278                unwrapped.dest_buf.as_ref().unwrap().write(&buf);
279
280                unwrapped.device_vmo_read += data_to_return;
281                unwrapped.device_read += data_to_return;
282
283                responder.send(&ReadResult::Info {
284                    0: ReadInfo { offset: 0, size: data_to_return as u64 },
285                })?;
286
287                let device_read = unwrapped.device_read;
288                let device_size = unwrapped.device_size;
289                if let Some(cb) = status_callback {
290                    cb(device_read, device_size);
291                }
292            }
293        }
294        return Ok(());
295    }
296
297    async fn close(&self) -> Result<(), Error> {
298        let unwrapped = self.inner.lock().await;
299        unwrapped.device.detach_vmo(unwrapped.device_vmo_id.take()).await?;
300        Ok(unwrapped.device.close().await?)
301    }
302}
303
304#[async_trait]
305impl PayloadStreamer for BlockDevicePayloadStreamer {
306    async fn service_payload_stream_requests(
307        self: Box<Self>,
308        stream: PayloadStreamRequestStream,
309        status_callback: Option<&dyn StatusCallback>,
310    ) -> Result<(), Error> {
311        let result = stream
312            .map(|result| result.context("failed request"))
313            .try_for_each(|request| async { self.handle_request(request, status_callback).await })
314            .await;
315
316        if let Err(e) = result {
317            // Still attempt to close the client but ignore any errors.
318            self.close().await.ok();
319            return Err(e);
320        }
321
322        self.close().await
323    }
324}
325
#[cfg(test)]
mod tests {

    use super::*;
    use anyhow::{anyhow, Context};
    use fidl_fuchsia_hardware_block::BlockMarker;
    use fidl_fuchsia_paver::{PayloadStreamMarker, PayloadStreamProxy};
    use fuchsia_async as fasync;
    use futures::future::try_join;
    use ramdevice_client::{RamdiskClient, RamdiskClientBuilder};
    use std::io::Cursor;
    use std::sync::{Arc, Mutex};
    use zx::{self as zx, HandleBased};

    /// Most recent values passed to the status callback.
    struct StatusUpdate {
        data_read: usize,
        data_size: usize,
    }

    impl StatusUpdate {
        /// Stores the latest progress report.
        fn record(&mut self, data_read: usize, data_size: usize) {
            self.data_read = data_read;
            self.data_size = data_size;
        }
    }

    /// Creates a PayloadStream channel and returns the client proxy together with the (not yet
    /// polled) server future for `streamer`.
    async fn serve_payload<'a>(
        streamer: Box<dyn PayloadStreamer>,
        status_callback: Option<&'a dyn StatusCallback>,
    ) -> Result<(PayloadStreamProxy, impl Future<Output = Result<(), Error>> + 'a), Error> {
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<PayloadStreamMarker>();
        let request_stream = server_end.into_stream();

        // The server future is returned rather than awaited so that the caller can drive the
        // client and the server concurrently.
        let server = streamer.service_payload_stream_requests(request_stream, status_callback);

        Ok((client_end.into_proxy(), server))
    }

    /// Builds a ramdisk backed by a VMO holding `src`.
    async fn create_ramdisk(src: Vec<u8>) -> Result<RamdiskClient, Error> {
        let backing_vmo = zx::Vmo::create(src.len() as u64).context("failed to create vmo")?;
        backing_vmo.write(&src, 0).context("failed to write vmo")?;
        let builder = RamdiskClientBuilder::new_with_vmo(backing_vmo, None);
        builder.build().await.context("failed to create ramdisk client")
    }

    /// Connects a new Block proxy to the device behind `ramdisk`.
    fn connect_to_ramdisk_block(ramdisk: &RamdiskClient) -> Result<BlockProxy, Error> {
        // TODO(https://fxbug.dev/42063787): Once ramdisk.open() no longer provides a
        // multiplexing channel, use open() to acquire the BlockProxy here.
        let controller =
            ramdisk.as_controller().ok_or_else(|| anyhow!("invalid ramdisk controller"))?;
        let (block_proxy, block_server) = fidl::endpoints::create_proxy::<BlockMarker>();
        let () = controller.connect_to_device_fidl(block_server.into_channel())?;
        Ok(block_proxy)
    }

    /// Registers a fresh VMO of `vmo_size` bytes with the server.
    ///
    /// Returns the raw status from the server, plus the local VMO handle when registration
    /// succeeded.
    async fn attach_vmo(
        vmo_size: usize,
        proxy: &PayloadStreamProxy,
    ) -> Result<(i32, Option<zx::Vmo>), anyhow::Error> {
        let local_vmo = zx::Vmo::create(vmo_size as u64)?;
        let registered_vmo = local_vmo.duplicate_handle(zx::Rights::SAME_RIGHTS)?;
        let status = proxy.register_vmo(registered_vmo).await?;
        let kept_vmo = if status == zx::Status::OK.into_raw() { Some(local_vmo) } else { None };
        Ok((status, kept_vmo))
    }

    /// Issues one ReadData call, checks every returned byte equals `byte`, and returns `read`
    /// advanced by the number of bytes received.
    async fn read_slice(
        vmo: &zx::Vmo,
        vmo_size: usize,
        proxy: &PayloadStreamProxy,
        byte: u8,
        read: usize,
    ) -> Result<usize, Error> {
        let info = match proxy.read_data().await? {
            ReadResult::Info(info) => info,
            ReadResult::Err(err) => panic!("read_data failed: {}", err),
            ReadResult::Eof(boolean) => panic!("unexpected eof: {}", boolean),
        };

        let mut written_buf: Vec<u8> = vec![0; vmo_size];
        let received = &mut written_buf[0..info.size as usize];
        vmo.read(received, info.offset)?;
        for (i, val) in received.iter().enumerate() {
            assert_eq!(*val, byte, "byte {} was wrong", i + read);
        }

        Ok(read + info.size as usize)
    }

    /// Issues one ReadData call and panics unless the server reports EOF.
    async fn expect_eof(proxy: &PayloadStreamProxy) -> Result<(), Error> {
        match proxy.read_data().await? {
            ReadResult::Eof(_) => Ok(()),
            _ => panic!("Should be at EOF but not at EOF!"),
        }
    }

    /// Reads the whole payload through `proxy`, checking the data and the status callback
    /// values after every chunk, and finally expects EOF.
    async fn run_client(
        proxy: PayloadStreamProxy,
        src_size: usize,
        dst_size: usize,
        byte: u8,
        callback_status: Arc<Mutex<StatusUpdate>>,
    ) -> Result<(), Error> {
        let vmo = attach_vmo(dst_size, &proxy).await?.1.expect("No vmo");
        let mut read = 0;
        while read < src_size {
            read = read_slice(&vmo, dst_size, &proxy, byte, read).await?;
            let latest = callback_status.lock().unwrap();
            assert_eq!(latest.data_size, src_size);
            assert_eq!(latest.data_read, read);
        }

        expect_eof(&proxy).await
    }

    /// Registers a VMO twice and checks that the second attempt is rejected.
    async fn register_twice(proxy: PayloadStreamProxy, dst_size: usize) -> Result<(), Error> {
        let (_, vmo) = attach_vmo(dst_size, &proxy).await?;
        assert!(vmo.is_some());
        let (err, _) = attach_vmo(dst_size, &proxy).await?;
        assert_eq!(err, zx::sys::ZX_ERR_ALREADY_BOUND);
        Ok(())
    }

    /// Streams `src_size` bytes of `byte` through a `dst_size` VMO, using either the block
    /// device streamer or the reader streamer.
    async fn do_one_test(
        src_size: usize,
        dst_size: usize,
        byte: u8,
        use_block_device_streamer: bool,
    ) -> Result<(), Error> {
        let payload: Vec<u8> = vec![byte; src_size];

        // Declared here to extend the ramdisk client's scope.
        let ramdisk_client: RamdiskClient;

        let streamer: Box<dyn PayloadStreamer> = if use_block_device_streamer {
            ramdisk_client = create_ramdisk(payload).await?;
            let ramdisk_block = connect_to_ramdisk_block(&ramdisk_client)?;
            Box::new(BlockDevicePayloadStreamer::new(ramdisk_block).await?)
        } else {
            Box::new(ReaderPayloadStreamer::new(Box::new(Cursor::new(payload)), src_size))
        };

        let status_update = Arc::new(Mutex::new(StatusUpdate { data_read: 0, data_size: 0 }));
        let status_callback = |data_read: usize, data_size: usize| {
            status_update.lock().unwrap().record(data_read, data_size)
        };
        let (proxy, server) = serve_payload(streamer, Some(&status_callback))
            .await
            .context("serve payload failed")?;
        let client = run_client(proxy, src_size, dst_size, byte, status_update.clone());
        try_join(server, client).await?;

        Ok(())
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_stream_simple() -> Result<(), Error> {
        do_one_test(200, 200, 0xaa, false).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_large_src_buffer() -> Result<(), Error> {
        do_one_test(4096 * 10, 4096, 0x76, false).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_large_dst_buffer() -> Result<(), Error> {
        do_one_test(4096, 4096 * 10, 0x76, false).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_large_buffers() -> Result<(), Error> {
        do_one_test(4096 * 100, 4096 * 100, 0xfa, false).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_multiple_registers() -> Result<(), Error> {
        let src_size = 4096 * 10;
        let dst_size = 4096;
        let byte: u8 = 0xab;
        let payload: Vec<u8> = vec![byte; src_size];
        let streamer: Box<dyn PayloadStreamer> =
            Box::new(ReaderPayloadStreamer::new(Box::new(Cursor::new(payload)), src_size));
        let (proxy, server) = serve_payload(streamer, None).await?;

        try_join(register_twice(proxy, dst_size), server).await?;

        Ok(())
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_block_streamer_simple() -> Result<(), Error> {
        do_one_test(4096, 8192, 0xaa, true).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_block_streamer_large_src_buffer() -> Result<(), Error> {
        do_one_test(4096 * 100, 8192, 0x76, true).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_block_streamer_large_dst_buffer() -> Result<(), Error> {
        do_one_test(4096, 8192 * 16, 0x76, true).await
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_block_streamer_multiple_registers() -> Result<(), Error> {
        let src_size = 8192 * 10;
        let dst_size = 8192;
        let byte: u8 = 0xab;
        let payload: Vec<u8> = vec![byte; src_size];
        let ramdisk_client = create_ramdisk(payload).await?;
        let ramdisk_block = connect_to_ramdisk_block(&ramdisk_client)?;
        let streamer: Box<dyn PayloadStreamer> =
            Box::new(BlockDevicePayloadStreamer::new(ramdisk_block).await?);
        let (proxy, server) = serve_payload(streamer, None).await?;

        try_join(register_twice(proxy, dst_size), server).await?;

        Ok(())
    }
}