1use crate::mem::{DeviceRange, DriverMem, DriverRange};
34use crate::queue::{Desc, DescChain, DescChainIter, DescError, DescType, DriverNotify};
35use crate::ring::{Desc as RingDesc, DescAccess};
36use thiserror::Error;
37
/// Amount of data left in a chain for the direction being walked, as reported by `remaining()`.
///
/// Both counts only cover descriptors of the expected direction (readable for a readable chain,
/// writable for a writable chain).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Remaining {
    /// Total number of bytes left.
    pub bytes: usize,
    /// Number of descriptors left; the unread tail of a partially consumed descriptor counts as
    /// one.
    pub descriptors: usize,
}
43
44#[derive(Error, Debug, Clone, PartialEq, Eq)]
46pub enum ChainError {
47 #[error("Error in descriptor chain: {0}")]
48 Desc(#[from] DescError),
49 #[error("Found readable descriptor after writable")]
50 ReadableAfterWritable,
51 #[error("Failed to translate descriptors driver range {0:?} into a device range")]
52 TranslateFailed(DriverRange),
53 #[error("Nested indirect chain is not supported by the virtio spec")]
54 InvalidNestedIndirectChain,
55}
56
57impl From<ChainError> for std::io::Error {
58 fn from(error: ChainError) -> Self {
59 std::io::Error::new(std::io::ErrorKind::Other, error)
60 }
61}
62
63#[derive(Debug, Clone)]
64struct IndirectDescChain<'a> {
65 range: DeviceRange<'a>,
66 next: Option<u16>,
67}
68
69impl<'a> IndirectDescChain<'a> {
70 fn new(range: DeviceRange<'a>) -> Self {
71 IndirectDescChain { range: range, next: Some(0) }
72 }
73
74 pub fn next(&mut self) -> Option<Result<Desc, DescError>> {
75 let index = self.next?;
76 match self.range.split_at(index as usize * std::mem::size_of::<RingDesc>()) {
77 None => Some(Err(DescError::InvalidIndex(index))),
78 Some((_, range)) => match range.try_ptr::<RingDesc>() {
79 None => Some(Err(DescError::InvalidIndex(index))),
80 Some(ptr) => {
81 let desc = unsafe { ptr.read_volatile() };
84 self.next = desc.next();
85 Some(desc.try_into())
86 }
87 },
88 }
89 }
90}
91
92struct State<'a, 'b, N: DriverNotify, M, const E: bool> {
97 chain: Option<DescChain<'a, 'b, N>>,
98 iter: DescChainIter<'a, 'b, N>,
99 current: Option<Desc>,
100 mem: &'a M,
101 indirect_chain: Option<IndirectDescChain<'a>>,
102}
103
104impl<'a, 'b, N: DriverNotify, M: DriverMem, const E: bool> State<'a, 'b, N, M, E> {
105 fn expected_access() -> DescAccess {
107 if E {
108 DescAccess::DeviceWrite
109 } else {
110 DescAccess::DeviceRead
111 }
112 }
113
114 fn next_desc(&mut self) -> Option<Result<Desc, ChainError>> {
115 fn into_desc(desc: Result<Desc, DescError>) -> Option<Result<Desc, ChainError>> {
116 match desc {
117 Ok(desc) => Some(Ok(desc)),
118 Err(e) => Some(Err(e.into())),
119 }
120 }
121
122 match self.current.take() {
123 None => {
124 if let Some(indirect_chain) = &mut self.indirect_chain {
127 match indirect_chain.next() {
129 None => {
130 self.indirect_chain = None;
132 into_desc(self.iter.next()?)
134 }
135 Some(desc_res) => into_desc(desc_res),
137 }
138 } else {
139 into_desc(self.iter.next()?)
141 }
142 }
143 Some(desc) => Some(Ok(desc)),
145 }
146 }
147
148 fn next_into_indirect(
149 &mut self,
150 range: DriverRange,
151 limit: usize,
152 ) -> Option<Result<DeviceRange<'a>, ChainError>> {
153 assert!(self.current.is_none());
154 if self.indirect_chain.is_some() {
155 return Some(Err(ChainError::InvalidNestedIndirectChain));
158 }
159
160 match self.mem.translate(range.clone()) {
161 Some(range) => {
162 self.indirect_chain = Some(IndirectDescChain::new(range));
163 self.next_with_limit(limit)
164 }
165 None => Some(Err(ChainError::TranslateFailed(range))),
166 }
167 }
168
169 fn into_device_range(
170 &mut self,
171 access: DescAccess,
172 range: DriverRange,
173 limit: usize,
174 ) -> Option<Result<DeviceRange<'a>, ChainError>> {
175 match (Self::expected_access(), access) {
176 (DescAccess::DeviceWrite, DescAccess::DeviceWrite)
179 | (DescAccess::DeviceRead, DescAccess::DeviceRead) => {
180 let range = if let Some((range, rest)) = range.split_at(limit) {
181 if rest.len() > 0 {
184 self.current = Some(Desc(DescType::Direct(access), rest));
185 }
186 range
187 } else {
188 range
191 };
192 Some(self.mem.translate(range.clone()).ok_or(ChainError::TranslateFailed(range)))
193 }
194 (DescAccess::DeviceWrite, DescAccess::DeviceRead) => {
197 self.iter.complete();
200 Some(Err(ChainError::ReadableAfterWritable))
201 }
202 (DescAccess::DeviceRead, DescAccess::DeviceWrite) => {
203 self.current = Some(Desc(DescType::Direct(access), range));
205 None
206 }
207 }
208 }
209
210 fn next_with_limit(&mut self, limit: usize) -> Option<Result<DeviceRange<'a>, ChainError>> {
211 match self.next_desc()? {
212 Ok(Desc(desc_type, range)) => match desc_type {
213 DescType::Direct(access) => self.into_device_range(access, range, limit),
214 DescType::Indirect => self.next_into_indirect(range, limit),
215 },
216 Err(e) => Some(Err(e.into())),
217 }
218 }
219
220 fn remaining(&self) -> Result<Remaining, ChainError> {
221 let mut state = State::<N, M, E> {
222 chain: None,
223 mem: self.mem,
224 iter: self.iter.clone(),
225 current: self.current.clone(),
226 indirect_chain: self.indirect_chain.clone(),
227 };
228 let mut bytes = 0;
229 let mut descriptors = 0;
230 while let Some(v) = state.next_with_limit(usize::MAX) {
231 bytes += v?.len();
232 descriptors += 1;
233 }
234 Ok(Remaining { bytes, descriptors })
235 }
236}
237
238impl<'a, 'b, N: DriverNotify, M> From<State<'a, 'b, N, M, false>> for State<'a, 'b, N, M, true> {
240 fn from(state: State<'a, 'b, N, M, false>) -> State<'a, 'b, N, M, true> {
241 State {
242 chain: state.chain,
243 iter: state.iter,
244 current: state.current,
245 mem: state.mem,
246 indirect_chain: state.indirect_chain,
247 }
248 }
249}
250
251#[derive(Error, Debug, Clone, PartialEq, Eq)]
255pub enum ChainCompleteError {
256 #[error("Unexpected readable descriptor found")]
257 ReadableRemaining,
258 #[error("Unexpected writable descriptor found")]
259 WritableRemaining,
260 #[error("Chain walk error {0} when checking for descriptors")]
261 Chain(#[from] ChainError),
262}
263
264pub struct ReadableChain<'a, 'b, N: DriverNotify, M: DriverMem> {
276 state: State<'a, 'b, N, M, false>,
277}
278
279impl<'a, 'b, N: DriverNotify, M: DriverMem> ReadableChain<'a, 'b, N, M> {
280 pub fn new(chain: DescChain<'a, 'b, N>, mem: &'a M) -> Self {
285 let iter = chain.iter();
286 ReadableChain {
287 state: State { chain: Some(chain), mem, iter, current: None, indirect_chain: None },
288 }
289 }
290
291 pub fn return_complete(self) -> Result<(), ChainCompleteError> {
302 WritableChain::from_readable(self)?.return_complete()
303 }
304
305 pub fn next_with_limit(&mut self, limit: usize) -> Option<Result<DeviceRange<'a>, ChainError>> {
320 self.state.next_with_limit(limit)
321 }
322
323 pub fn next(&mut self) -> Option<Result<DeviceRange<'a>, ChainError>> {
329 self.next_with_limit(usize::MAX)
330 }
331
332 pub fn remaining(&self) -> Result<Remaining, ChainError> {
338 self.state.remaining()
339 }
340}
341
342impl<'a, 'b, N: DriverNotify, M: DriverMem> std::io::Read for ReadableChain<'a, 'b, N, M> {
343 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
344 match self.next_with_limit(buf.len()) {
345 None => Ok(0),
346 Some(Err(e)) => Err(e.into()),
347 Some(Ok(range)) => {
348 let len = range.len();
349 assert!(len <= buf.len());
350 let ptr = range.try_ptr().unwrap();
353 unsafe { std::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), len) };
362 Ok(len)
363 }
364 }
365 }
366}
367
368pub struct WritableChain<'a, 'b, N: DriverNotify, M: DriverMem> {
391 state: State<'a, 'b, N, M, true>,
392 written: u32,
393}
394
395impl<'a, 'b, N: DriverNotify, M: DriverMem> WritableChain<'a, 'b, N, M> {
396 pub fn new(chain: DescChain<'a, 'b, N>, mem: &'a M) -> Result<Self, ChainCompleteError> {
401 WritableChain::from_readable(ReadableChain::new(chain, mem))
402 }
403
404 pub fn new_ignore_readable(
409 chain: DescChain<'a, 'b, N>,
410 mem: &'a M,
411 ) -> Result<Self, ChainError> {
412 WritableChain::from_incomplete_readable(ReadableChain::new(chain, mem))
413 }
414
415 pub fn from_readable(
419 mut readable: ReadableChain<'a, 'b, N, M>,
420 ) -> Result<Self, ChainCompleteError> {
421 match readable.next() {
422 None => Ok(()),
423 Some(Ok(_)) => Err(ChainCompleteError::ReadableRemaining),
424 Some(Err(e)) => Err(e.into()),
425 }?;
426 Ok(WritableChain { state: readable.state.into(), written: 0 })
427 }
428
429 pub fn from_incomplete_readable(
434 mut readable: ReadableChain<'a, 'b, N, M>,
435 ) -> Result<Self, ChainError> {
436 while let Some(_) = readable.next().transpose()? {}
438 Ok(WritableChain { state: readable.state.into(), written: 0 })
439 }
440
441 pub fn return_complete(mut self) -> Result<(), ChainCompleteError> {
445 match self.next() {
446 None => Ok(()),
447 Some(Ok(_)) => Err(ChainCompleteError::WritableRemaining),
448 Some(Err(e)) => Err(e.into()),
449 }
450 }
451
452 pub fn next_with_limit(&mut self, limit: usize) -> Option<Result<DeviceRange<'a>, ChainError>> {
456 self.state.next_with_limit(limit)
457 }
458
459 pub fn next(&mut self) -> Option<Result<DeviceRange<'a>, ChainError>> {
463 self.next_with_limit(usize::MAX)
464 }
465
466 pub fn remaining(&self) -> Result<Remaining, ChainError> {
470 self.state.remaining()
471 }
472
473 pub fn add_written(&mut self, written: u32) {
488 self.written += written;
489 }
490}
491
492impl<'a, 'b, N: DriverNotify, M: DriverMem> Drop for WritableChain<'a, 'b, N, M> {
493 fn drop(&mut self) {
494 self.state.chain.take().unwrap().return_written(self.written);
495 }
496}
497
498impl<'a, 'b, N: DriverNotify, M: DriverMem> std::io::Write for WritableChain<'a, 'b, N, M> {
499 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
500 match self.next_with_limit(buf.len()) {
501 None => Ok(0),
502 Some(Err(e)) => Err(e.into()),
503 Some(Ok(range)) => {
504 let len = range.len();
505 assert!(len <= buf.len());
506 let ptr = range.try_mut_ptr().unwrap();
509 unsafe { libc::memmove(ptr, buf.as_ptr() as *const libc::c_void, len) };
523 self.add_written(len as u32);
524 Ok(len)
525 }
526 }
527 }
528 fn flush(&mut self) -> std::io::Result<()> {
529 Ok(())
530 }
531}
532
#[cfg(test)]
mod tests {
    //! Tests for `ReadableChain` / `WritableChain` driven through the fake queue.

    use super::*;
    use crate::fake_queue::{Chain, IdentityDriverMem, TestQueue};
    use std::io::{Read, Write};

    /// Asserts that `result` is a successfully produced range whose bytes equal `expected`.
    fn check_read<'a>(result: Option<Result<DeviceRange<'a>, ChainError>>, expected: &[u8]) {
        let range = result.unwrap().unwrap();
        assert_eq!(range.len(), expected.len());
        assert_eq!(
            // Reads `range.len()` bytes through the range's own pointer.
            unsafe { std::slice::from_raw_parts::<u8>(range.try_ptr().unwrap(), range.len()) },
            expected
        );
    }

    /// Asserts that an `(address, length)` pair from the fake queue's used data iterator
    /// refers to bytes equal to `expected`.
    fn check_returned(result: Option<(u64, u32)>, expected: &[u8]) {
        let (data, len) = result.unwrap();
        assert_eq!(len as usize, expected.len());
        assert_eq!(
            // Dereferences the driver address directly; presumably valid because the test uses
            // `IdentityDriverMem` (assumed identity mapping — confirm in fake_queue).
            unsafe { std::slice::from_raw_parts::<u8>(data as usize as *const u8, len as usize) },
            expected
        );
    }

    /// Asserts that `result` is a range of exactly `expected` bytes, without touching it.
    fn test_write<'a>(result: Option<Result<DeviceRange<'a>, ChainError>>, expected: u32) {
        let range = result.unwrap().unwrap();
        assert_eq!(range.len(), expected as usize);
    }

    /// Asserts that `result` is a range of `data.len()` bytes and copies `data` into it.
    fn test_write_data<'a>(result: Option<Result<DeviceRange<'a>, ChainError>>, data: &[u8]) {
        let range = result.unwrap().unwrap();
        assert_eq!(range.len(), data.len());
        // Writes through the range's own mutable pointer.
        unsafe { std::slice::from_raw_parts_mut::<u8>(range.try_mut_ptr().unwrap(), range.len()) }
            .copy_from_slice(data);
    }

    /// Shared body for the smoke tests. Expects a published chain with readable descriptors
    /// [1..=4], [5..=8], [9..=12] and two 4-byte writable descriptors.
    fn test_smoke_test_body<'a>(state: &mut TestQueue<'a>, driver_mem: &'a IdentityDriverMem) {
        {
            // Walk the readable descriptors. `remaining` counts whole descriptors as well as
            // the unread tail of a partially consumed one.
            let mut readable = ReadableChain::new(state.queue.next_chain().unwrap(), driver_mem);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 12, descriptors: 3 }));
            check_read(readable.next(), &[1, 2, 3, 4]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 8, descriptors: 2 }));
            check_read(readable.next_with_limit(2), &[5, 6]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 6, descriptors: 2 }));
            check_read(readable.next_with_limit(200), &[7, 8]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 4, descriptors: 1 }));
            check_read(readable.next_with_limit(4), &[9, 10, 11, 12]);
            assert_eq!(readable.remaining(), Ok(Remaining { bytes: 0, descriptors: 0 }));
            assert!(readable.next().is_none());

            // Fill the first writable descriptor in two pieces, leave the second untouched and
            // report only those 4 bytes as written.
            let mut writable = WritableChain::from_readable(readable).unwrap();
            test_write_data(writable.next_with_limit(3), &[1, 2, 3]);
            test_write_data(writable.next(), &[4]);
            test_write(writable.next(), 4);
            assert!(writable.next().is_none());

            writable.add_written(4);
        }

        // Dropping `writable` returned the chain with the recorded written count.
        let returned = state.fake_queue.next_used().unwrap();
        assert_eq!(returned.written(), 4);
        let mut iter = returned.data_iter();
        check_returned(iter.next(), &[1, 2, 3, 4]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_smoke_test() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish(Chain::with_data::<u8>(
                &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                &[4, 4],
                &driver_mem
            ))
            .is_some());
        test_smoke_test_body(&mut state, &driver_mem);
    }

    /// Same as `test_smoke_test`, but the chain is published through an indirect table.
    #[test]
    fn test_smoke_test_indirect_chain() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish_indirect(
                Chain::with_data::<u8>(
                    &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                    &[4, 4],
                    &driver_mem
                ),
                &driver_mem
            )
            .is_some());

        test_smoke_test_body(&mut state, &driver_mem)
    }

    /// Shared body for the `std::io` adapter tests: reads and writes cross descriptor
    /// boundaries and mix with direct `next_with_limit` calls.
    fn test_io_body<'a>(state: &mut TestQueue<'a>, driver_mem: &'a IdentityDriverMem) {
        {
            let mut readable = ReadableChain::new(state.queue.next_chain().unwrap(), driver_mem);
            let mut buffer: [u8; 2] = [0; 2];
            assert!(readable.read_exact(&mut buffer).is_ok());
            assert_eq!(&buffer, &[1, 2]);
            check_read(readable.next_with_limit(1), &[3]);
            let mut buffer: [u8; 5] = [0; 5];
            assert!(readable.read_exact(&mut buffer).is_ok());
            assert_eq!(&buffer, &[4, 5, 6, 7, 8]);
            let mut buffer = Vec::new();
            assert!(readable.read_to_end(&mut buffer).is_ok());
            assert_eq!(buffer, vec![9, 10, 11, 12]);

            // 8 bytes of writable space: a 9th byte must be rejected.
            let mut writable = WritableChain::from_readable(readable).unwrap();
            assert!(writable.write_all(&[1, 2, 3, 4, 5]).is_ok());
            assert!(writable.write_all(&[6, 7, 8]).is_ok());
            assert!(writable.write_all(&[9]).is_err());
            assert!(writable.flush().is_ok());
        }
        // `Write::write` recorded the written bytes automatically.
        let returned = state.fake_queue.next_used().unwrap();
        assert_eq!(returned.written(), 8);
        let mut iter = returned.data_iter();
        check_returned(iter.next(), &[1, 2, 3, 4]);
        check_returned(iter.next(), &[5, 6, 7, 8]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_io() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish(Chain::with_data::<u8>(
                &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                &[4, 4],
                &driver_mem
            ))
            .is_some());
        test_io_body(&mut state, &driver_mem)
    }

    /// Same as `test_io`, but the chain is published through an indirect table.
    #[test]
    fn test_io_indirect_chain() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);
        assert!(state
            .fake_queue
            .publish_indirect(
                Chain::with_data::<u8>(
                    &[&[1, 2, 3, 4], &[5, 6, 7, 8], &[9, 10, 11, 12]],
                    &[4, 4],
                    &driver_mem
                ),
                &driver_mem
            )
            .is_some());
        test_io_body(&mut state, &driver_mem)
    }

    #[test]
    fn test_readable_completed() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);

        // Publishes a chain with the given readable/writable descriptor lengths and consumes one
        // readable range: the whole descriptor when `limit` is 0, otherwise at most `limit`
        // bytes. Then checks the result of `return_complete`. The chain must reach the used ring
        // whether or not that succeeds.
        let mut test_return = |read, write, limit, expected| {
            assert!(state
                .fake_queue
                .publish(Chain::with_lengths(read, write, &driver_mem))
                .is_some());
            let mut readable = ReadableChain::new(state.queue.next_chain().unwrap(), &driver_mem);
            if limit == 0 {
                assert!(readable.next().unwrap().is_ok());
            } else {
                assert!(readable.next_with_limit(limit).unwrap().is_ok());
            }
            assert_eq!(readable.return_complete(), expected);
            assert!(state.fake_queue.next_used().is_some());
        };

        test_return(&[4], &[], 0, Ok(()));
        test_return(&[4], &[], 4, Ok(()));
        test_return(&[4, 2], &[], 0, Err(ChainCompleteError::ReadableRemaining));
        test_return(&[4], &[], 2, Err(ChainCompleteError::ReadableRemaining));
        test_return(&[4], &[4], 2, Err(ChainCompleteError::ReadableRemaining));
        test_return(&[4], &[4], 0, Err(ChainCompleteError::WritableRemaining));
        test_return(&[4], &[4], 4, Err(ChainCompleteError::WritableRemaining));
    }

    #[test]
    fn test_make_writable() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);

        // No readable descriptors: `new` succeeds.
        assert!(state.fake_queue.publish(Chain::with_lengths(&[], &[4], &driver_mem)).is_some());
        assert!(WritableChain::new(state.queue.next_chain().unwrap(), &driver_mem).is_ok());
        assert!(state.fake_queue.next_used().is_some());

        // Readable descriptors present: `new` refuses, but the chain is still returned.
        assert!(state.fake_queue.publish(Chain::with_lengths(&[4], &[4], &driver_mem)).is_some());
        assert_eq!(
            WritableChain::new(state.queue.next_chain().unwrap(), &driver_mem).err().unwrap(),
            ChainCompleteError::ReadableRemaining
        );
        assert!(state.fake_queue.next_used().is_some());

        // `new_ignore_readable` skips over the readable descriptors.
        assert!(state.fake_queue.publish(Chain::with_lengths(&[4], &[4], &driver_mem)).is_some());
        assert!(WritableChain::new_ignore_readable(state.queue.next_chain().unwrap(), &driver_mem)
            .is_ok());
        assert!(state.fake_queue.next_used().is_some());
    }

    #[test]
    fn test_writable_completed() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);

        // Writable counterpart of the helper in `test_readable_completed`.
        let mut test_return = |read, write, limit, expected| {
            assert!(state
                .fake_queue
                .publish(Chain::with_lengths(read, write, &driver_mem))
                .is_some());
            let mut writable =
                WritableChain::new(state.queue.next_chain().unwrap(), &driver_mem).unwrap();
            if limit == 0 {
                assert!(writable.next().unwrap().is_ok());
            } else {
                assert!(writable.next_with_limit(limit).unwrap().is_ok());
            }
            assert_eq!(writable.return_complete(), expected);
            assert!(state.fake_queue.next_used().is_some());
        };

        test_return(&[], &[4], 0, Ok(()));
        test_return(&[], &[4], 4, Ok(()));
        test_return(&[], &[4, 2], 0, Err(ChainCompleteError::WritableRemaining));
        test_return(&[], &[4], 2, Err(ChainCompleteError::WritableRemaining));
    }

    /// A readable descriptor after a writable one is reported as `ReadableAfterWritable`, and
    /// the chain is still returned afterwards.
    #[test]
    fn test_bad_chain() {
        let driver_mem = IdentityDriverMem::new();
        let mut state = TestQueue::new(32, &driver_mem);

        let desc1 = driver_mem.new_range(10).unwrap();
        let desc2 = driver_mem.new_range(20).unwrap();

        assert!(state
            .fake_queue
            .publish(Chain::with_exact_data(&[
                (DescAccess::DeviceWrite, desc1.get().start as u64, desc1.len() as u32),
                (DescAccess::DeviceRead, desc2.get().start as u64, desc2.len() as u32)
            ]))
            .is_some());

        {
            let mut writable =
                WritableChain::new_ignore_readable(state.queue.next_chain().unwrap(), &driver_mem)
                    .unwrap();
            assert!(writable.next().unwrap().is_ok());
            assert_eq!(writable.next().unwrap().err().unwrap(), ChainError::ReadableAfterWritable);
        }
        assert!(state.fake_queue.next_used().is_some());
    }
}
807}