// virtio_device/chain.rs

// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Descriptor chain walking.
//!
//! The goal of the [`ReadableChain`] and [`WritableChain`] is to present a byte-wise view of the
//! descriptor chain, and facilitate safe reading and writing to the chain.
//!
//! Although walking these chains feels similar to using an iterator, the chains do not directly
//! implement the [`std::iter::Iterator`] trait as iterator composition works against being able to
//! then convert a [`ReadableChain`] into a [`WritableChain`]. An iterator can be built on top of
//! these interfaces, but it has not been done here yet.
//!
//! In addition to walking byte ranges via the [`next`](ReadableChain::next) or
//! [`next_with_limit`](ReadableChain::next_with_limit) methods, the [`Read`](std::io::Read) and
//! [`Write`](std::io::Write) traits are implemented for [`ReadableChain`] and [`WritableChain`]
//! respectively.
//!
//! When using the [`std::io::Write`] interface for the [`WritableChain`] the amount written is
//! tracked, alleviating the need to manually perform [`add_written`](WritableChain::add_written).
//! Although not always appropriate depending on the particular virtio device, the
//! [`Read`](std::io::Read)/[`Write`](std::io::Write) interfaces are therefore the preferred way to
//! manipulate the chains.
//!
//! The requirement from the virtio specification that all readable descriptors occur before all
//! writable descriptors is enforced here, with explicit types that indicate what is being walked.
//! Transitioning from the [`ReadableChain`] to the [`WritableChain`] is an explicit operation that
//! allows for optional checking to ensure all readable descriptors have been consumed. This allows
//! devices to easily check if the driver is violating any protocol assumptions on descriptor
//! layouts.
33use crate::mem::{DeviceRange, DriverMem, DriverRange};
34use crate::queue::{Desc, DescChain, DescChainIter, DescError, DescType, DriverNotify};
35use crate::ring::{Desc as RingDesc, DescAccess};
36use thiserror::Error;
37
/// Count of the bytes and descriptors not yet walked in a chain.
///
/// Produced by [`ReadableChain::remaining`] and [`WritableChain::remaining`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Remaining {
    /// Total bytes remaining across all unwalked descriptors.
    pub bytes: usize,
    /// Number of descriptors those bytes are spread over.
    pub descriptors: usize,
}
43
/// Errors from walking a descriptor chain.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ChainError {
    /// The underlying queue reported a malformed descriptor chain.
    #[error("Error in descriptor chain: {0}")]
    Desc(#[from] DescError),
    /// The driver violated the virtio requirement that all readable
    /// descriptors precede all writable ones.
    #[error("Found readable descriptor after writable")]
    ReadableAfterWritable,
    /// A driver-provided guest range could not be translated into a
    /// device-accessible range.
    #[error("Failed to translate descriptors driver range {0:?} into a device range")]
    TranslateFailed(DriverRange),
    /// An indirect descriptor was found inside an indirect table; the virtio
    /// spec forbids nesting indirect chains.
    #[error("Nested indirect chain is not supported by the virtio spec")]
    InvalidNestedIndirectChain,
}
56
57impl From<ChainError> for std::io::Error {
58    fn from(error: ChainError) -> Self {
59        std::io::Error::new(std::io::ErrorKind::Other, error)
60    }
61}
62
/// Cursor over an indirect descriptor table held in guest memory.
///
/// Walks descriptors packed into a translated [`DeviceRange`] rather than the
/// queue's main descriptor table.
#[derive(Debug, Clone)]
struct IndirectDescChain<'a> {
    // Guest memory containing the packed array of ring descriptors.
    range: DeviceRange<'a>,
    // Index of the next descriptor to read; `None` once the table is exhausted.
    next: Option<u16>,
}
68
69impl<'a> IndirectDescChain<'a> {
70    fn new(range: DeviceRange<'a>) -> Self {
71        IndirectDescChain { range: range, next: Some(0) }
72    }
73
74    pub fn next(&mut self) -> Option<Result<Desc, DescError>> {
75        let index = self.next?;
76        match self.range.split_at(index as usize * std::mem::size_of::<RingDesc>()) {
77            None => Some(Err(DescError::InvalidIndex(index))),
78            Some((_, range)) => match range.try_ptr::<RingDesc>() {
79                None => Some(Err(DescError::InvalidIndex(index))),
80                Some(ptr) => {
81                    // * SAFETY
82                    // try_ptr guarantees that returned Some(ptr) is valid for read
83                    let desc = unsafe { ptr.read_volatile() };
84                    self.next = desc.next();
85                    Some(desc.try_into())
86                }
87            },
88        }
89    }
90}
91
// State for a generic walker that can walk either the readable or writable portions of a
// chain. Ideally `E` would be of type DescAccess to indicate the kind of access this is iterating
// over, but due to current limits in const generics we have to use a bool instead. It gets
// converted to DescAccess in expected_access.
struct State<'a, 'b, N: DriverNotify, M, const E: bool> {
    // Owned descriptor chain. `None` only for the throwaway walker built by
    // `remaining`, and after `Drop` of a `WritableChain` takes it back.
    chain: Option<DescChain<'a, 'b, N>>,
    // Iterator over descriptors in the chain's main descriptor table.
    iter: DescChainIter<'a, 'b, N>,
    // Unconsumed remainder of a partially walked descriptor; served before
    // `iter` is advanced again.
    current: Option<Desc>,
    // Used to translate driver ranges into device-accessible ranges.
    mem: &'a M,
    // Set while walking an indirect descriptor table; cleared when exhausted.
    indirect_chain: Option<IndirectDescChain<'a>>,
}
103
impl<'a, 'b, N: DriverNotify, M: DriverMem, const E: bool> State<'a, 'b, N, M, E> {
    // Hack for const generics limitation to convert bool->DescAccess.
    // E == true means this walker expects device-writable descriptors.
    fn expected_access() -> DescAccess {
        if E {
            DescAccess::DeviceWrite
        } else {
            DescAccess::DeviceRead
        }
    }

    // Produce the next descriptor to process, preferring in order: the stashed
    // remainder in `self.current`, the active indirect table (if any), and
    // finally the main chain iterator.
    fn next_desc(&mut self) -> Option<Result<Desc, ChainError>> {
        // Lift a queue-level DescError into the ChainError this walker reports.
        fn into_desc(desc: Result<Desc, DescError>) -> Option<Result<Desc, ChainError>> {
            match desc {
                Ok(desc) => Some(Ok(desc)),
                Err(e) => Some(Err(e.into())),
            }
        }

        match self.current.take() {
            None => {
                // Nothing in the current, time to read a new descriptor
                // Let's see if we have an active indirect chain
                if let Some(indirect_chain) = &mut self.indirect_chain {
                    // Keep processing the indirect chain
                    match indirect_chain.next() {
                        None => {
                            // Indirect chain has been fully processed
                            self.indirect_chain = None;
                            // Read from the normal chain
                            into_desc(self.iter.next()?)
                        }
                        // Read from the indirect chain
                        Some(desc_res) => into_desc(desc_res),
                    }
                } else {
                    // Read from the normal chain
                    into_desc(self.iter.next()?)
                }
            }
            // Read the remains of the self.current
            Some(desc) => Some(Ok(desc)),
        }
    }

    // Begin walking the indirect descriptor table referenced by `range`, then
    // continue the walk honoring the caller's original `limit`.
    fn next_into_indirect(
        &mut self,
        range: DriverRange,
        limit: usize,
    ) -> Option<Result<DeviceRange<'a>, ChainError>> {
        // Only reachable from next_with_limit after next_desc drained `current`.
        assert!(self.current.is_none());
        if self.indirect_chain.is_some() {
            // Supplying the nested indirect chain violates the virtio spec
            // Either our processing is wrong or guest driver has a bug
            return Some(Err(ChainError::InvalidNestedIndirectChain));
        }

        match self.mem.translate(range.clone()) {
            Some(range) => {
                self.indirect_chain = Some(IndirectDescChain::new(range));
                self.next_with_limit(limit)
            }
            None => Some(Err(ChainError::TranslateFailed(range))),
        }
    }

    // Resolve a direct descriptor into a translated DeviceRange of at most
    // `limit` bytes, enforcing the expected access direction for this walker.
    fn into_device_range(
        &mut self,
        access: DescAccess,
        range: DriverRange,
        limit: usize,
    ) -> Option<Result<DeviceRange<'a>, ChainError>> {
        match (Self::expected_access(), access) {
            // If descriptor we found matches what we expected then we return as much as we can
            // based on the requested limit.
            (DescAccess::DeviceWrite, DescAccess::DeviceWrite)
            | (DescAccess::DeviceRead, DescAccess::DeviceRead) => {
                let range = if let Some((range, rest)) = range.split_at(limit) {
                    // If we could split the range, and there is non-zero remaining, then stash the
                    // remaining portion for later and return the range that was split.
                    if rest.len() > 0 {
                        self.current = Some(Desc(DescType::Direct(access), rest));
                    }
                    range
                } else {
                    // Split failed, meaning we have less than was requested so we just return all
                    // of it.
                    range
                };
                Some(self.mem.translate(range.clone()).ok_or(ChainError::TranslateFailed(range)))
            }
            // This is a readable descriptor, while we are expecting a writable one.
            // This indicates a corrupt descriptor chain, so return an error.
            (DescAccess::DeviceWrite, DescAccess::DeviceRead) => {
                // Consume the rest of the iterator to ensure any future calls to next_with_limit
                // fail.
                self.iter.complete();
                Some(Err(ChainError::ReadableAfterWritable))
            }
            (DescAccess::DeviceRead, DescAccess::DeviceWrite) => {
                // Put the descriptor back as we might want to walk the writable section later.
                self.current = Some(Desc(DescType::Direct(access), range));
                None
            }
        }
    }

    // Walk to the next contiguous device range of at most `limit` bytes.
    // `None` means no more descriptors of the expected access kind remain.
    fn next_with_limit(&mut self, limit: usize) -> Option<Result<DeviceRange<'a>, ChainError>> {
        match self.next_desc()? {
            Ok(Desc(desc_type, range)) => match desc_type {
                DescType::Direct(access) => self.into_device_range(access, range, limit),
                DescType::Indirect => self.next_into_indirect(range, limit),
            },
            Err(e) => Some(Err(e.into())),
        }
    }

    // Count the bytes and descriptors remaining by walking a clone of this
    // walker's position; `self` is left untouched.
    fn remaining(&self) -> Result<Remaining, ChainError> {
        // chain: None keeps the DescChain owned by `self`; the clone only
        // needs the iterator and stashed cursor state.
        let mut state = State::<N, M, E> {
            chain: None,
            mem: self.mem,
            iter: self.iter.clone(),
            current: self.current.clone(),
            indirect_chain: self.indirect_chain.clone(),
        };
        let mut bytes = 0;
        let mut descriptors = 0;
        while let Some(v) = state.next_with_limit(usize::MAX) {
            bytes += v?.len();
            descriptors += 1;
        }
        Ok(Remaining { bytes, descriptors })
    }
}
237
238// Allow easily transforming a read chain into a write.
239impl<'a, 'b, N: DriverNotify, M> From<State<'a, 'b, N, M, false>> for State<'a, 'b, N, M, true> {
240    fn from(state: State<'a, 'b, N, M, false>) -> State<'a, 'b, N, M, true> {
241        State {
242            chain: state.chain,
243            iter: state.iter,
244            current: state.current,
245            mem: state.mem,
246            indirect_chain: state.indirect_chain,
247        }
248    }
249}
250
/// Errors resulting from completing a chain.
///
/// These errors are from the optional interfaces for completing and converting chains.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ChainCompleteError {
    /// A readable descriptor had not yet been walked when completion or
    /// conversion to a [`WritableChain`] was requested.
    #[error("Unexpected readable descriptor found")]
    ReadableRemaining,
    /// A writable descriptor had not yet been walked at completion.
    #[error("Unexpected writable descriptor found")]
    WritableRemaining,
    /// Walking the chain itself failed; see the wrapped [`ChainError`].
    #[error("Chain walk error {0} when checking for descriptors")]
    Chain(#[from] ChainError),
}
263
/// Access the readable portion of a descriptor chain.
///
/// Provides access to the read-only portion of a descriptor chain. Can be
/// [constructed directly](ReadableChain::new) from a [`DescChain`] and once finished with can
/// either be dropped or converted to a [`WritableChain`] if there are writable portions as well.
///
/// As the [`ReadableChain`] takes ownership of the [`DescChain`] dropping the [`ReadableChain`]
/// will automatically return the [`DescChain`] to the [`Queue`](crate::queue::Queue).
///
/// For devices and protocols where it is useful, the chain can also be explicitly returned via the
/// [`return_complete`](#return_complete) method to validate full consumption of the chain.
pub struct ReadableChain<'a, 'b, N: DriverNotify, M: DriverMem> {
    // Walker fixed to the device-read (E = false) portion of the chain.
    state: State<'a, 'b, N, M, false>,
}
278
279impl<'a, 'b, N: DriverNotify, M: DriverMem> ReadableChain<'a, 'b, N, M> {
280    /// Construct a [`ReadableChain`] from a [`DescChain`].
281    ///
282    /// Requires a reference to a [`DriverMem`] in order to perform translation into
283    /// [`DeviceRange`].
284    pub fn new(chain: DescChain<'a, 'b, N>, mem: &'a M) -> Self {
285        let iter = chain.iter();
286        ReadableChain {
287            state: State { chain: Some(chain), mem, iter, current: None, indirect_chain: None },
288        }
289    }
290
291    /// Immediately return a fully consumed chain.
292    ///
293    /// This both drops the chain, thus returning the underlying [`DescChain`] to the [`Queue`]
294    /// (crate::queue::Queue), and also checks if it was fully walked, generating an error if not.
295    /// Fully walked here means that there are no readable or writable sections that had not been
296    /// iterated over.
297    ///
298    /// For virtio queues where the device is expected to fully consume what it is sent, and there
299    /// is not expected to be anything to write, this provides a way to both check for correct
300    /// device and driver functionality.
301    pub fn return_complete(self) -> Result<(), ChainCompleteError> {
302        WritableChain::from_readable(self)?.return_complete()
303    }
304
305    /// Request the next range of readable bytes, up to a limit.
306    ///
307    /// As the [`DeviceRange`] returned here represents a contiguous range this may return a smaller
308    /// range than requested by `limit`, even if there is more readable descriptor(s) remaining. In
309    /// this way the caller is directly exposed to size of the underlying descriptors in the chain
310    /// as queued by the driver.
311    ///
312    /// A return value of `None` indicates there are no more readable descriptors, however there
313    /// may still be readable descriptors.
314    ///
315    /// Should this ever return a `Some(Err(_))` it will always yield a `None` in future calls as
316    /// the chain will be deemed corrupt. If walking and attempting to recover from corrupt chains
317    /// is desirable, beyond just reporting an error, then you must use the [`DescChain`] directly
318    /// and not this interface.
319    pub fn next_with_limit(&mut self, limit: usize) -> Option<Result<DeviceRange<'a>, ChainError>> {
320        self.state.next_with_limit(limit)
321    }
322
323    /// Request the next range of readable bytes.
324    ///
325    /// Similar to [`next_with_limit`](#next_with_limit) except limit is implicitly `usize::MAX`.
326    /// This will therefore walk the descriptors in the structure as they were provided by the
327    /// driver.
328    pub fn next(&mut self) -> Option<Result<DeviceRange<'a>, ChainError>> {
329        self.next_with_limit(usize::MAX)
330    }
331
332    /// Query readable bytes and descriptors remaining.
333    ///
334    /// Returns the number of readable bytes and descriptors remaining in the chain. This does not
335    /// imply that calling [`next_with_limit`](#next_with_limit) with the result will return that
336    /// much, see [`next_with_limit`](#next_with_limit) for more details.
337    pub fn remaining(&self) -> Result<Remaining, ChainError> {
338        self.state.remaining()
339    }
340}
341
impl<'a, 'b, N: DriverNotify, M: DriverMem> std::io::Read for ReadableChain<'a, 'b, N, M> {
    /// Copy up to `buf.len()` readable bytes out of the chain.
    ///
    /// Returns `Ok(0)` once the readable portion is exhausted; chain walk
    /// failures surface as [`std::io::Error`] via the `From<ChainError>` impl.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.next_with_limit(buf.len()) {
            None => Ok(0),
            Some(Err(e)) => Err(e.into()),
            Some(Ok(range)) => {
                let len = range.len();
                // next_with_limit(buf.len()) never returns more than requested.
                assert!(len <= buf.len());
                // This unwrap is safe as we are requesting a u8 pointer that has no alignment
                // constraints.
                let ptr = range.try_ptr().unwrap();
                // In the implementation of std::io::Write for WritableChain we use libc::memmove in
                // an attempt to ensure our copy cannot be elided. Here in the read path we do not
                // need to make guarantees as this not MMIO memory and reading has no side effects.
                // As such if the compiler can determine that the read data is not used, we would
                // very much like it to elide the copy.
                // SAFETY: We meet the requirements of copy_nonoverlapping since:
                // * buf is a reference to a slice and assumed to be valid
                // * ptr comes from `range`, which is a DeviceRange and is defined to be valid.
                unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), len) };
                Ok(len)
            }
        }
    }
}
367
/// Access the writable portion of a descriptor chain.
///
/// Provides access to the write-only portion of a descriptor chain. If no readable portion a
/// [`WritableChain`] can be constructed directly from a [`DescChain`], either
/// [generating errors](WritableChain::new) if there are readable portions, or
/// [ignoring them](WritableChain::new_ignore_readable). Otherwise a [`ReadableChain`] can be
/// [converted](WritableChain::from_readable) into a [`WritableChain`], with a similar option to
/// [ignore any remaining readable](WritableChain::from_incomplete_readable).
///
/// As the [`WritableChain`] takes ownership of the [`DescChain`] dropping the [`WritableChain`]
/// will automatically return the [`DescChain`] to the [`Queue`](crate::queue::Queue). To report
/// how much was written the [`WritableChain`] has an internal counter of how much you have claimed
/// to have written via [`add_written`](WritableChain::add_written). Walking the chain via
/// [`next`](WritableChain::next) or [`next_with_limit`](WritableChain::next_with_limit) does not
/// automatically increment the written counter as the [`WritableChain`] cannot assume how much of
/// the returned range was written to.
///
/// Writing to the chain via the [`std::io::Write`] trait will automatically increment the written
/// counter.
///
/// For devices and protocols where it is useful, the chain can also be explicitly returned via the
/// [`return_complete`](#return_complete) method to validate the full chain was written to.
pub struct WritableChain<'a, 'b, N: DriverNotify, M: DriverMem> {
    // Walker fixed to the device-write (E = true) portion of the chain.
    state: State<'a, 'b, N, M, true>,
    // Bytes the device claims to have written; reported to the queue on drop.
    written: u32,
}
394
impl<'a, 'b, N: DriverNotify, M: DriverMem> WritableChain<'a, 'b, N, M> {
    /// Construct a [`WritableChain`] from a [`DescChain`].
    ///
    /// Requires a reference to a [`DriverMem`] in order to perform translation into
    /// [`DeviceRange`]. Generates an error if there are any readable portions.
    pub fn new(chain: DescChain<'a, 'b, N>, mem: &'a M) -> Result<Self, ChainCompleteError> {
        WritableChain::from_readable(ReadableChain::new(chain, mem))
    }

    /// Construct a [`WritableChain`] from a [`DescChain`], ignoring some errors.
    ///
    /// Same as [`new`](#new) but ignores any readable descriptors. It may still generate an error
    /// as a corrupt chain may be noticed when it is walked to skip any readable descriptors.
    pub fn new_ignore_readable(
        chain: DescChain<'a, 'b, N>,
        mem: &'a M,
    ) -> Result<Self, ChainError> {
        WritableChain::from_incomplete_readable(ReadableChain::new(chain, mem))
    }

    /// Convert a [`ReadableChain`] to a [`WritableChain`]
    ///
    /// Generates an error if there are still readable portions of the chain left.
    pub fn from_readable(
        mut readable: ReadableChain<'a, 'b, N, M>,
    ) -> Result<Self, ChainCompleteError> {
        // A remaining readable descriptor means the caller did not fully
        // consume the readable section, which this conversion treats as error.
        match readable.next() {
            None => Ok(()),
            Some(Ok(_)) => Err(ChainCompleteError::ReadableRemaining),
            Some(Err(e)) => Err(e.into()),
        }?;
        Ok(WritableChain { state: readable.state.into(), written: 0 })
    }

    /// Convert a [`ReadableChain`] to a [`WritableChain`]
    ///
    /// Skips any remaining readable descriptors to construct a [`WritableChain`]. May still
    /// generate an error if there was a problem walking the chain.
    pub fn from_incomplete_readable(
        mut readable: ReadableChain<'a, 'b, N, M>,
    ) -> Result<Self, ChainError> {
        // Walk the readable iterator to the end, returning an error if one is found
        while let Some(_) = readable.next().transpose()? {}
        Ok(WritableChain { state: readable.state.into(), written: 0 })
    }

    /// Immediately return a fully consumed chain.
    ///
    /// Similar to [`ReadableChain::return_complete`].
    pub fn return_complete(mut self) -> Result<(), ChainCompleteError> {
        match self.next() {
            None => Ok(()),
            Some(Ok(_)) => Err(ChainCompleteError::WritableRemaining),
            Some(Err(e)) => Err(e.into()),
        }
    }

    /// Request the next range of writable bytes, up to a limit.
    ///
    /// Similar to [`ReadableChain::next_with_limit`]
    pub fn next_with_limit(&mut self, limit: usize) -> Option<Result<DeviceRange<'a>, ChainError>> {
        self.state.next_with_limit(limit)
    }

    /// Request the next range of writable bytes.
    ///
    /// Similar to [`ReadableChain::next`]
    pub fn next(&mut self) -> Option<Result<DeviceRange<'a>, ChainError>> {
        self.next_with_limit(usize::MAX)
    }

    /// Query writable bytes and descriptors remaining.
    ///
    /// Similar to [`ReadableChain::remaining`]
    pub fn remaining(&self) -> Result<Remaining, ChainError> {
        self.state.remaining()
    }

    /// Increments the written bytes counter.
    ///
    /// If descriptor ranges returned from [`next`](#next) and [`next_with_limit`](#next_with_limit)
    /// are actually written to then the amount that is written needs to be added by calling this
    /// method, as the [`WritableChain`] itself does not know if, or how much, might have been
    /// written to the returned ranges.
    ///
    /// Note if using the [`std::io::Write`] trait implementation to write to the chain this method
    /// does not need to be called, as the trait implementation will call it for you. You only need
    /// to call this if actually directly calling [`next`](#next) or
    /// [`next_with_limit`](#next_with_limit).
    ///
    /// `add_written` is cumulative and can be called multiple times. No checking of this value is
    /// performed and it is up to the caller to choose to honor the virtio specification.
    pub fn add_written(&mut self, written: u32) {
        // NOTE(review): plain `+=` panics on overflow in debug builds if a
        // caller over-reports; callers are trusted per the doc comment above.
        self.written += written;
    }
}
491
impl<'a, 'b, N: DriverNotify, M: DriverMem> Drop for WritableChain<'a, 'b, N, M> {
    // Return the DescChain to the queue, reporting the accumulated written
    // byte count to the driver.
    fn drop(&mut self) {
        // unwrap is safe: a live WritableChain always holds Some(chain);
        // the option is only taken here, during drop.
        self.state.chain.take().unwrap().return_written(self.written);
    }
}
497
impl<'a, 'b, N: DriverNotify, M: DriverMem> std::io::Write for WritableChain<'a, 'b, N, M> {
    /// Copy up to `buf.len()` bytes into the chain, automatically updating the
    /// written counter.
    ///
    /// Returns `Ok(0)` once the writable portion is exhausted.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self.next_with_limit(buf.len()) {
            None => Ok(0),
            Some(Err(e)) => Err(e.into()),
            Some(Ok(range)) => {
                let len = range.len();
                // next_with_limit(buf.len()) never returns more than requested.
                assert!(len <= buf.len());
                // This unwrap is safe as we are requesting a u8 pointer that has no alignment
                // constraints.
                let ptr = range.try_mut_ptr().unwrap();
                // We use libc::memmove over ptr::copy_nonoverlapping as ptr::copy_nonoverlapping
                // does not provide a strong guarantee that the copy cannot be elided. Ideally we
                // would perform a volatile copy, however volatile_copy_nonoverlapping_memory
                // intrinsic has no stable interface, and manually writing a loop of
                // ptr::write_volatile cannot be optimized equivalently. As such, performing an ffi
                // call to something we know cannot elide our operation, we can thus guarantee our
                // copy happens.
                // SAFETY: The requirements we need to satisfy for memmove are the same as
                // ptr::copy_nonoverlapping and this is safe since:
                // * buf is a reference to a slice and assumed to be valid
                // * ptr comes from `range`, which is a DeviceRange, and is defined to be valid
                // * len is checked for both of these ranges, and so the pointers are valid for the
                //   full range of bytes.
                unsafe { libc::memmove(ptr, buf.as_ptr() as *const libc::c_void, len) };
                self.add_written(len as u32);
                Ok(len)
            }
        }
    }
    // Writes go straight to guest memory; there is nothing buffered to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
532
533#[cfg(test)]
534mod tests {
535    use super::*;
536    use crate::fake_queue::{Chain, IdentityDriverMem, TestQueue};
537    use std::io::{Read, Write};
538
    // Assert that `result` yields a readable range whose length and contents
    // equal `expected`.
    fn check_read<'a>(result: Option<Result<DeviceRange<'a>, ChainError>>, expected: &[u8]) {
        let range = result.unwrap().unwrap();
        assert_eq!(range.len(), expected.len());
        assert_eq!(
            // Calling slice::from_raw_parts is valid since
            // * This memory was allocated from a single TestDeviceRange block to become a
            //   descriptor.
            // * No references are hold elsewhere, mutable or otherwise. Other pointers exist, but
            //   they will not be dereferenced for the duration we hold this as a slice.
            // * fake_queue::ChainBuilder initialized the memory, not that types of 'u8' need any
            //   initialization.
            unsafe { std::slice::from_raw_parts::<u8>(range.try_ptr().unwrap(), range.len()) },
            expected
        );
    }
554
    // Assert that a (guest address, length) pair returned by the fake queue
    // refers to memory whose contents equal `expected`.
    fn check_returned(result: Option<(u64, u32)>, expected: &[u8]) {
        let (data, len) = result.unwrap();
        assert_eq!(len as usize, expected.len());
        assert_eq!(
            // See check_read for safety argument.
            unsafe { std::slice::from_raw_parts::<u8>(data as usize as *const u8, len as usize) },
            expected
        );
    }
564
    // Assert that the next writable range is exactly `expected` bytes long;
    // the contents are not written or checked.
    fn test_write<'a>(result: Option<Result<DeviceRange<'a>, ChainError>>, expected: u32) {
        let range = result.unwrap().unwrap();
        assert_eq!(range.len(), expected as usize);
    }
569
    // Fill the next writable range with `data`, asserting the range size
    // matches the data exactly.
    fn test_write_data<'a>(result: Option<Result<DeviceRange<'a>, ChainError>>, data: &[u8]) {
        let range = result.unwrap().unwrap();
        assert_eq!(range.len(), data.len());
        // See check_read for safety argument.
        unsafe { std::slice::from_raw_parts_mut::<u8>(range.try_mut_ptr().unwrap(), range.len()) }
            .copy_from_slice(data);
    }
577
    // Shared body for the smoke tests: walks a published chain of three 4-byte
    // readable descriptors [1..4][5..8][9..12] plus two 4-byte writable ones,
    // checking `remaining` bookkeeping after every step.
    fn test_smoke_test_body<'a>(state: &mut TestQueue<'a>, driver_mem: &'a IdentityDriverMem) {
        {
            let mut readable = ReadableChain::new(state.queue.next_chain().unwrap(), driver_mem);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 12, descriptors: 3 }));
            check_read(readable.next(), &[1, 2, 3, 4]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 8, descriptors: 2 }));
            // A limit smaller than the descriptor splits it; the remainder
            // still counts as one descriptor.
            check_read(readable.next_with_limit(2), &[5, 6]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 6, descriptors: 2 }));
            // A limit larger than the descriptor only returns the remainder.
            check_read(readable.next_with_limit(200), &[7, 8]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 4, descriptors: 1 }));
            check_read(readable.next_with_limit(4), &[9, 10, 11, 12]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 0, descriptors: 0 }));
            assert!(readable.next().is_none());

            let mut writable = WritableChain::from_readable(readable).unwrap();
            test_write_data(writable.next_with_limit(3), &[1, 2, 3]);
            test_write_data(writable.next(), &[4]);
            test_write(writable.next(), 4);
            assert!(writable.next().is_none());

            // Walking via next/next_with_limit does not track written bytes;
            // only this explicit count is reported back to the queue.
            writable.add_written(4);
        }

        let returned = state.fake_queue.next_used().unwrap();
        assert_eq!(returned.written(), 4);
        let mut iter = returned.data_iter();
        check_returned(iter.next(), &[1, 2, 3, 4]);
        assert!(iter.next().is_none());
    }
607
    // Smoke test over a directly published (non-indirect) descriptor chain.
    #[test]
    fn test_smoke_test() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish(Chain::with_data::<u8>(
                &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                &[4, 4],
                &driver_mem
            ))
            .is_some());
        test_smoke_test_body(&mut state, &driver_mem);
    }
622
    // Same smoke test, but the chain is published through an indirect
    // descriptor table; the walker should behave identically.
    #[test]
    fn test_smoke_test_indirect_chain() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish_indirect(
                Chain::with_data::<u8>(
                    &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                    &[4, 4],
                    &driver_mem
                ),
                &driver_mem
            )
            .is_some());

        test_smoke_test_body(&mut state, &driver_mem)
    }
641
    // Shared body exercising the std::io::Read/Write implementations, mixing
    // trait calls with direct next_with_limit walking on the same chain.
    fn test_io_body<'a>(state: &mut TestQueue<'a>, driver_mem: &'a IdentityDriverMem) {
        {
            let mut readable = ReadableChain::new(state.queue.next_chain().unwrap(), driver_mem);
            let mut buffer: [u8; 2] = [0; 2];
            assert!(readable.read_exact(&mut buffer).is_ok());
            assert_eq!(&buffer, &[1, 2]);
            check_read(readable.next_with_limit(1), &[3]);
            // read_exact spans the remainder of one descriptor and part of the next.
            let mut buffer: [u8; 5] = [0; 5];
            assert!(readable.read_exact(&mut buffer).is_ok());
            assert_eq!(&buffer, &[4, 5, 6, 7, 8]);
            let mut buffer = Vec::new();
            assert!(readable.read_to_end(&mut buffer).is_ok());
            assert_eq!(buffer, vec![9, 10, 11, 12]);

            let mut writable = WritableChain::from_readable(readable).unwrap();
            assert!(writable.write_all(&[1, 2, 3, 4, 5]).is_ok());
            assert!(writable.write_all(&[6, 7, 8]).is_ok());
            // Only 8 writable bytes exist, so a further write must fail.
            assert!(writable.write_all(&[9]).is_err());
            assert!(writable.flush().is_ok());
        }
        let returned = state.fake_queue.next_used().unwrap();
        // write_all through the trait tracked the written count automatically.
        assert_eq!(returned.written(), 8);
        let mut iter = returned.data_iter();
        check_returned(iter.next(), &[1, 2, 3, 4]);
        check_returned(iter.next(), &[5, 6, 7, 8]);
        assert!(iter.next().is_none());
    }
669
    // Read/Write trait test over a directly published descriptor chain.
    #[test]
    fn test_io() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish(Chain::with_data::<u8>(
                &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                &[4, 4],
                &driver_mem
            ))
            .is_some());
        test_io_body(&mut state, &driver_mem)
    }
684
685    #[test]
686    fn test_io_indirect_chain() {
687        let driver_mem = IdentityDriverMem::new();
688        let mut state = TestQueue::new(32, &driver_mem);
689        assert!(state
690            .fake_queue
691            .publish_indirect(
692                Chain::with_data::<u8>(
693                    &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
694                    &[4, 4],
695                    &driver_mem
696                ),
697                &driver_mem
698            )
699            .is_some());
700        test_io_body(&mut state, &driver_mem)
701    }
702
703    #[test]
704    fn test_readable_completed() {
705        let driver_mem = IdentityDriverMem::new();
706        let mut state = TestQueue::new(32, &driver_mem);
707
708        let mut test_return = |read, write, limit, expected| {
709            assert!(state
710                .fake_queue
711                .publish(Chain::with_lengths(read, write, &driver_mem))
712                .is_some());
713            let mut readable = ReadableChain::new(state.queue.next_chain().unwrap(), &driver_mem);
714            if limit == 0 {
715                assert!(readable.next().unwrap().is_ok());
716            } else {
717                assert!(readable.next_with_limit(limit).unwrap().is_ok());
718            }
719            assert_eq!(readable.return_complete(), expected);
720            assert!(state.fake_queue.next_used().is_some());
721        };
722
723        test_return(&[4], &[], 0, Ok(()));
724        test_return(&[4], &[], 4, Ok(()));
725        test_return(&[4, 2], &[], 0, Err(ChainCompleteError::ReadableRemaining));
726        test_return(&[4], &[], 2, Err(ChainCompleteError::ReadableRemaining));
727        test_return(&[4], &[4], 2, Err(ChainCompleteError::ReadableRemaining));
728        test_return(&[4], &[4], 0, Err(ChainCompleteError::WritableRemaining));
729        test_return(&[4], &[4], 4, Err(ChainCompleteError::WritableRemaining));
730    }
731
732    #[test]
733    fn test_make_writable() {
734        let driver_mem = IdentityDriverMem::new();
735        let mut state = TestQueue::new(32, &driver_mem);
736
737        assert!(state.fake_queue.publish(Chain::with_lengths(&[], &[4], &driver_mem)).is_some());
738        assert!(WritableChain::new(state.queue.next_chain().unwrap(), &driver_mem).is_ok());
739        assert!(state.fake_queue.next_used().is_some());
740
741        assert!(state.fake_queue.publish(Chain::with_lengths(&[4], &[4], &driver_mem)).is_some());
742        assert_eq!(
743            WritableChain::new(state.queue.next_chain().unwrap(), &driver_mem).err().unwrap(),
744            ChainCompleteError::ReadableRemaining
745        );
746        assert!(state.fake_queue.next_used().is_some());
747
748        assert!(state.fake_queue.publish(Chain::with_lengths(&[4], &[4], &driver_mem)).is_some());
749        assert!(WritableChain::new_ignore_readable(state.queue.next_chain().unwrap(), &driver_mem)
750            .is_ok());
751        assert!(state.fake_queue.next_used().is_some());
752    }
753
754    #[test]
755    fn test_writable_completed() {
756        let driver_mem = IdentityDriverMem::new();
757        let mut state = TestQueue::new(32, &driver_mem);
758
759        let mut test_return = |read, write, limit, expected| {
760            assert!(state
761                .fake_queue
762                .publish(Chain::with_lengths(read, write, &driver_mem))
763                .is_some());
764            let mut writable =
765                WritableChain::new(state.queue.next_chain().unwrap(), &driver_mem).unwrap();
766            if limit == 0 {
767                assert!(writable.next().unwrap().is_ok());
768            } else {
769                assert!(writable.next_with_limit(limit).unwrap().is_ok());
770            }
771            assert_eq!(writable.return_complete(), expected);
772            assert!(state.fake_queue.next_used().is_some());
773        };
774
775        test_return(&[], &[4], 0, Ok(()));
776        test_return(&[], &[4], 4, Ok(()));
777        test_return(&[], &[4, 2], 0, Err(ChainCompleteError::WritableRemaining));
778        test_return(&[], &[4], 2, Err(ChainCompleteError::WritableRemaining));
779    }
780
781    #[test]
782    fn test_bad_chain() {
783        let driver_mem = IdentityDriverMem::new();
784        let mut state = TestQueue::new(32, &driver_mem);
785
786        // Get memory for two descriptors so we can build our custom chain.
787        let desc1 = driver_mem.new_range(10).unwrap();
788        let desc2 = driver_mem.new_range(20).unwrap();
789
790        assert!(state
791            .fake_queue
792            .publish(Chain::with_exact_data(&[
793                (DescAccess::DeviceWrite, desc1.get().start as u64, desc1.len() as u32),
794                (DescAccess::DeviceRead, desc2.get().start as u64, desc2.len() as u32)
795            ]))
796            .is_some());
797
798        {
799            let mut writable =
800                WritableChain::new_ignore_readable(state.queue.next_chain().unwrap(), &driver_mem)
801                    .unwrap();
802            assert!(writable.next().unwrap().is_ok());
803            assert_eq!(writable.next().unwrap().err().unwrap(), ChainError::ReadableAfterWritable);
804        }
805        assert!(state.fake_queue.next_used().is_some());
806    }
807}