1use fuchsia_async::{self as fasync, PacketReceiver, ReceiverRegistration};
6
7use futures::channel::mpsc;
8use futures::{Stream, StreamExt, TryStreamExt};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use thiserror::Error;
12
/// Number of bytes between consecutive queue notification addresses in a bell trap.
///
/// A trapped range is divided into `QUEUE_NOTIFY_MULTIPLIER`-byte slots, one per queue, so an
/// address maps to queue `(addr - base) / QUEUE_NOTIFY_MULTIPLIER`. Both the trap base and its
/// length must be multiples of this value.
// NOTE(review): presumably mirrors the virtio PCI `notify_off_multiplier` advertised to the
// guest — confirm against the device's capability setup.
const QUEUE_NOTIFY_MULTIPLIER: usize = 4;
30
31#[derive(Error, Debug, PartialEq, Eq)]
32pub enum BellError {
33 #[error("Received unexpected packet {0:?}")]
34 UnexpectedPacket(zx::Packet),
35 #[error("Trap address {0:?} did not map to a queue")]
36 BadAddress(zx::GPAddr),
37}
38
39#[derive(Debug, Eq, PartialEq)]
40enum Packet {
41 Bell(zx::GPAddr),
42 Other(zx::Packet),
43}
44
45#[derive(Debug)]
47pub struct PortForwarder {
48 channel: mpsc::UnboundedSender<Packet>,
49}
50
51impl PacketReceiver for PortForwarder {
52 fn receive_packet(&self, packet: zx::Packet) {
53 let packet = if let zx::PacketContents::GuestBell(bell) = packet.contents() {
54 Packet::Bell(bell.addr())
55 } else {
56 Packet::Other(packet)
57 };
58 self.channel.unbounded_send(packet).unwrap();
63 }
64}
65
66#[derive(Debug)]
73pub struct GuestBellTrap<T = ReceiverRegistration<PortForwarder>> {
74 _registration: T,
75 channel: mpsc::UnboundedReceiver<Packet>,
76 base: zx::GPAddr,
77 num_queues: u16,
78}
79
80impl GuestBellTrap {
81 pub fn new(guest: &zx::Guest, base: zx::GPAddr, len: usize) -> Result<Self, zx::Status> {
90 let (tx, rx) = mpsc::unbounded();
91 let registration = fasync::EHandle::local()
92 .register_receiver(std::sync::Arc::new(PortForwarder { channel: tx }));
93 guest.set_trap_bell(base, len, registration.port(), registration.key())?;
94 Self::with_registration(base, len, rx, registration)
95 }
96}
97
98impl<T> GuestBellTrap<T> {
99 fn with_registration(
100 base: zx::GPAddr,
101 len: usize,
102 rx: mpsc::UnboundedReceiver<Packet>,
103 registration: T,
104 ) -> Result<Self, zx::Status> {
105 if (base.0 % QUEUE_NOTIFY_MULTIPLIER) != 0 {
107 return Err(zx::Status::INVALID_ARGS);
108 }
109 let num_queues = (len / QUEUE_NOTIFY_MULTIPLIER) as u16;
110 if num_queues as usize * QUEUE_NOTIFY_MULTIPLIER != len {
111 return Err(zx::Status::INVALID_ARGS);
112 }
113 if num_queues == 0 {
115 return Err(zx::Status::INVALID_ARGS);
116 }
117 Ok(GuestBellTrap { _registration: registration, channel: rx, base, num_queues })
118 }
119
120 pub fn queue_for_addr(&self, addr: zx::GPAddr) -> Option<u16> {
126 let queue =
127 ((addr.0.checked_sub(self.base.0)?) / QUEUE_NOTIFY_MULTIPLIER).try_into().ok()?;
128 if queue >= self.num_queues {
129 None
130 } else {
131 Some(queue)
132 }
133 }
134}
135
136impl<T: Unpin> GuestBellTrap<T> {
137 pub async fn complete<'a, N>(
142 self,
143 device: &crate::Device<'a, N>,
144 ) -> Result<(), crate::DeviceError> {
145 self.err_into()
146 .try_for_each(|queue| futures::future::ready(device.notify_queue(queue as u16)))
147 .await
148 }
149
150 pub async fn complete_or_pending<'a, N>(
157 maybe_trap: Option<Self>,
158 device: &crate::Device<'a, N>,
159 ) -> Result<(), crate::DeviceError> {
160 match maybe_trap {
161 Some(bell) => bell.complete(device).await,
162 None => futures::future::pending().await,
163 }
164 }
165}
166
167impl<T: Unpin> Stream for GuestBellTrap<T> {
168 type Item = Result<u16, BellError>;
169
170 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171 self.channel.poll_next_unpin(cx).map(|maybe_packet| {
172 let packet = maybe_packet?;
175 match packet {
176 Packet::Bell(addr) => {
177 Some(self.queue_for_addr(addr).ok_or(BellError::BadAddress(addr)))
178 }
179 Packet::Other(packet) => Some(Err(BellError::UnexpectedPacket(packet))),
180 }
181 })
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt;

    /// Builds a trap with an unused channel and a unit registration.
    fn make_trap(base: usize, len: usize) -> Result<GuestBellTrap<()>, zx::Status> {
        GuestBellTrap::with_registration(zx::GPAddr(base), len, mpsc::unbounded().1, ())
    }

    #[test]
    fn trap_size() {
        let rejected: [(usize, usize); 7] = [
            // Base address not a multiple of QUEUE_NOTIFY_MULTIPLIER.
            (3, 4),
            (1, 4),
            // Zero-length region.
            (8, 0),
            // Length not a multiple of QUEUE_NOTIFY_MULTIPLIER.
            (8, 1),
            (8, 3),
            (8, 9),
            (8, 42),
        ];
        for &(base, len) in rejected.iter() {
            assert_eq!(
                make_trap(base, len).err(),
                Some(zx::Status::INVALID_ARGS),
                "base={} len={}",
                base,
                len
            );
        }

        assert!(make_trap(64, 12).is_ok());
    }

    #[test]
    fn queue_conversion() {
        // Three queues: slots [80, 84), [84, 88) and [88, 92).
        let bell = make_trap(80, 12).unwrap();

        let cases: &[(usize, Option<u16>)] = &[
            // Below the base.
            (79, None),
            (76, None),
            // First slot, including unaligned addresses inside it.
            (80, Some(0)),
            (81, Some(0)),
            (83, Some(0)),
            (84, Some(1)),
            (88, Some(2)),
            (91, Some(2)),
            // Past the last slot.
            (92, None),
            (94, None),
            (128, None),
        ];
        for &(addr, expected) in cases {
            assert_eq!(bell.queue_for_addr(zx::GPAddr(addr)), expected, "addr={}", addr);
        }
    }

    #[fasync::run_until_stalled(test)]
    async fn packet_stream() {
        let (tx, rx) = mpsc::unbounded();
        let bell = GuestBellTrap::with_registration(zx::GPAddr(64), 12, rx, ()).unwrap();

        for addr in [64usize, 68, 100].iter().copied() {
            tx.unbounded_send(Packet::Bell(zx::GPAddr(addr))).unwrap();
        }

        let mut stream = bell.peekable();
        // Packets were queued before polling, so an item is available immediately.
        assert!(Pin::new(&mut stream).peek().now_or_never().is_some());

        assert_eq!(stream.next().await, Some(Ok(0)));
        assert_eq!(stream.next().await, Some(Ok(1)));
        // 100 lies past the three-queue region starting at 64.
        assert_eq!(stream.next().await, Some(Err(BellError::BadAddress(zx::GPAddr(100)))));

        // `tx` is still alive, so the drained stream is pending rather than terminated.
        assert!(Pin::new(&mut stream).peek().now_or_never().is_none());
    }
}
278}