1#![deny(unsafe_op_in_unsafe_fn, missing_docs)]
7
8use std::pin::Pin;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Weak};
11use std::task::{Context, Poll};
12
13use fidl_next::protocol::NonBlockingTransport;
14use fidl_next::protocol::fuchsia::channel::Buffer;
15use fidl_next::{CHUNK_SIZE, ClientEnd, Executor, HasExecutor, ServerEnd, Transport};
16use futures::task::AtomicWaker;
17use libasync::callback_state::CallbackSharedState;
18use libasync::{JoinHandle, OnDispatcher};
19use libasync_sys::{async_begin_wait, async_dispatcher, async_wait};
20use zx::sys::{
21 ZX_CHANNEL_PEER_CLOSED, ZX_CHANNEL_READABLE, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_CANCELED,
22 ZX_ERR_PEER_CLOSED, ZX_ERR_SHOULD_WAIT, ZX_OK, zx_channel_read, zx_channel_write,
23 zx_packet_signal_t, zx_status_t,
24};
25use zx::{Channel, Status};
26
/// A Zircon [`Channel`] paired with the dispatcher used to wait on it, usable as a
/// `fidl_next` [`Transport`].
#[derive(Debug, PartialEq)]
pub struct AsyncChannel<D> {
    /// Dispatcher used to start readability waits and to spawn tasks via [`FidlExecutor`].
    dispatcher: D,
    /// The underlying channel, reference-counted so that `split` can hand it to both halves.
    channel: Arc<Channel>,
}
33
34impl<D> AsyncChannel<D> {
35 pub fn new_on_dispatcher(dispatcher: D, channel: Channel) -> Self {
37 Self { dispatcher, channel: Arc::new(channel) }
38 }
39
40 pub fn client_from_zx_channel_on_dispatcher<P>(
43 from: ClientEnd<P, Channel>,
44 dispatcher: D,
45 ) -> ClientEnd<P, Self> {
46 let channel = from.into_untyped();
47 ClientEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
48 }
49
50 pub fn server_from_zx_channel_on_dispatcher<P>(
53 from: ServerEnd<P, Channel>,
54 dispatcher: D,
55 ) -> ServerEnd<P, Self> {
56 let channel = from.into_untyped();
57 ServerEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
58 }
59}
60
61impl<D: Default> AsyncChannel<D> {
62 pub fn new(channel: Channel) -> Self {
65 Self::new_on_dispatcher(D::default(), channel)
66 }
67
68 pub fn client_from_zx_channel<P>(from: ClientEnd<P, Channel>) -> ClientEnd<P, Self> {
71 Self::client_from_zx_channel_on_dispatcher(from, D::default())
72 }
73
74 pub fn server_from_zx_channel<P>(from: ServerEnd<P, Channel>) -> ServerEnd<P, Self> {
77 Self::server_from_zx_channel_on_dispatcher(from, D::default())
78 }
79}
80
81impl<D: OnDispatcher> Transport for AsyncChannel<D> {
82 type Error = Status;
83 type Shared = Arc<Channel>;
84 type Exclusive = Exclusive<D>;
85 type SendBuffer = Buffer;
86 type SendFutureState = SendFutureState;
87 type RecvFutureState = RecvFutureState;
88 type RecvBuffer = Buffer;
89
90 fn split(self) -> (Self::Shared, Self::Exclusive) {
91 let channel = self.channel;
92 let object = channel.raw_handle();
93 (
94 channel.clone(),
95 Exclusive {
96 dispatcher: self.dispatcher,
97 callback_state: CallbackState::new(
98 async_wait {
99 handler: Some(RecvCallbackState::handler),
100 object,
101 trigger: ZX_CHANNEL_PEER_CLOSED | ZX_CHANNEL_READABLE,
102 ..Default::default()
103 },
104 RecvCallbackState {
105 _channel: channel,
106 canceled: AtomicBool::new(false),
107 waker: AtomicWaker::new(),
108 },
109 ),
110 },
111 )
112 }
113
114 fn acquire(_shared: &Self::Shared) -> Self::SendBuffer {
115 Buffer::new()
116 }
117
118 fn begin_send(_: &Self::Shared, buffer: Self::SendBuffer) -> Self::SendFutureState {
119 SendFutureState { buffer }
120 }
121
122 fn poll_send(
123 future_state: Pin<&mut Self::SendFutureState>,
124 _: &mut Context<'_>,
125 shared: &Self::Shared,
126 ) -> Poll<Result<(), Option<Self::Error>>> {
127 Poll::Ready(Self::send_immediately(future_state.get_mut(), shared))
128 }
129
130 fn begin_recv(
131 _shared: &Self::Shared,
132 exclusive: &mut Self::Exclusive,
133 ) -> Self::RecvFutureState {
134 RecvFutureState {
135 buffer: Some(Buffer::new()),
136 callback_state: Arc::downgrade(&exclusive.callback_state),
137 }
138 }
139
140 fn poll_recv(
141 mut future_state: Pin<&mut Self::RecvFutureState>,
142 cx: &mut Context<'_>,
143 shared: &Self::Shared,
144 exclusive: &mut Self::Exclusive,
145 ) -> Poll<Result<Self::RecvBuffer, Option<Self::Error>>> {
146 let buffer = future_state.buffer.as_mut().unwrap();
147
148 let mut actual_bytes = 0;
149 let mut actual_handles = 0;
150
151 loop {
152 let result = unsafe {
153 zx_channel_read(
154 shared.raw_handle(),
155 0,
156 buffer.chunks.as_mut_ptr().cast(),
157 buffer.handles.as_mut_ptr().cast(),
158 (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
159 buffer.handles.capacity() as u32,
160 &mut actual_bytes,
161 &mut actual_handles,
162 )
163 };
164
165 match result {
166 ZX_OK => {
167 unsafe {
168 buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
169 buffer.handles.set_len(actual_handles as usize);
170 }
171 return Poll::Ready(Ok(future_state.buffer.take().unwrap()));
172 }
173 ZX_ERR_PEER_CLOSED => return Poll::Ready(Err(None)),
174 ZX_ERR_BUFFER_TOO_SMALL => {
175 let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
176 buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
177 buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
178 }
179 ZX_ERR_SHOULD_WAIT => {
180 exclusive.wait_readable(cx)?;
181 return Poll::Pending;
182 }
183 raw => return Poll::Ready(Err(Some(Status::from_raw(raw)))),
184 }
185 }
186 }
187}
188
189impl<D: OnDispatcher> NonBlockingTransport for AsyncChannel<D> {
190 fn send_immediately(
191 future_state: &mut Self::SendFutureState,
192 shared: &Self::Shared,
193 ) -> Result<(), Option<Self::Error>> {
194 let result = unsafe {
195 zx_channel_write(
196 shared.raw_handle(),
197 0,
198 future_state.buffer.chunks.as_ptr().cast::<u8>(),
199 (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
200 future_state.buffer.handles.as_ptr().cast(),
201 future_state.buffer.handles.len() as u32,
202 )
203 };
204
205 match result {
206 ZX_OK => {
207 unsafe {
209 future_state.buffer.handles.set_len(0);
210 }
211 Ok(())
212 }
213 ZX_ERR_PEER_CLOSED => Err(None),
214 _ => Err(Some(Status::from_raw(result))),
215 }
216 }
217}
218
/// A `fidl_next` [`Executor`] that spawns futures on the wrapped dispatcher.
pub struct FidlExecutor<D>(D);
222
223impl<D> std::ops::Deref for FidlExecutor<D> {
224 type Target = D;
225 fn deref(&self) -> &Self::Target {
226 &self.0
227 }
228}
229
230impl<D> From<D> for FidlExecutor<D> {
231 fn from(value: D) -> Self {
232 FidlExecutor(value)
233 }
234}
235
236impl<D: OnDispatcher + 'static> Executor for FidlExecutor<D> {
237 type JoinHandle<T: 'static> = JoinHandle<T>;
238
239 fn spawn<F>(&self, future: F) -> Self::JoinHandle<F::Output>
240 where
241 F: Future + Send + 'static,
242 F::Output: Send + 'static,
243 {
244 self.0.compute(future).detach_on_drop()
245 }
246}
247
248impl<D: OnDispatcher> fidl_next::RunsTransport<AsyncChannel<D>> for FidlExecutor<D> {}
249
250impl<D: OnDispatcher + 'static> HasExecutor for AsyncChannel<D> {
251 type Executor = FidlExecutor<D>;
252
253 fn executor(&self) -> Self::Executor {
254 FidlExecutor(self.dispatcher.clone())
255 }
256}
257
/// The `async_wait` registration stored together with its [`RecvCallbackState`]. The
/// dispatcher callback receives a `*mut async_wait` and recovers the combined state from it
/// with `CallbackState::from_raw_ptr`.
type CallbackState = CallbackSharedState<async_wait, RecvCallbackState>;
259
/// The receive-side (exclusive) half of a split [`AsyncChannel`].
#[doc(hidden)]
pub struct Exclusive<D> {
    /// Wait registration, cancellation flag, and waker for the channel.
    callback_state: Arc<CallbackState>,
    /// Dispatcher on which readability waits are started.
    dispatcher: D,
}
265
impl<D: OnDispatcher> Exclusive<D> {
    /// Registers `cx`'s waker and, if no wait is outstanding, starts an `async_wait` on the
    /// dispatcher. The wait fires when the channel becomes readable or its peer closes.
    ///
    /// # Errors
    ///
    /// Returns [`Status::CANCELED`] once the handler has observed a canceled wait. Otherwise
    /// propagates any error from `on_maybe_dispatcher` or `async_begin_wait`.
    fn wait_readable(&mut self, cx: &Context<'_>) -> Result<(), Status> {
        // Register before inspecting any state, so a wake issued by the handler after this
        // point reaches the current task.
        self.callback_state.waker.register(cx.waker());
        if self.callback_state.canceled.load(Ordering::Relaxed) {
            return Err(Status::CANCELED);
        }

        // Besides `self.callback_state`, the strong reference this file creates is the one
        // handed to a pending wait below. `RecvFutureState` holds only a `Weak`, upgraded
        // briefly in its `Drop`. So a count above one means a wait is already outstanding.
        // NOTE(review): the handler calls `waker.wake()` before its reclaimed reference is
        // dropped. On a multi-threaded dispatcher, a poll that lands between those two points
        // could see a count of 2 here and return Pending with no wait armed. Confirm this
        // cannot lose a wakeup.
        if Arc::strong_count(&self.callback_state) > 1 {
            return Ok(());
        }
        self.dispatcher.on_maybe_dispatcher(|dispatcher| {
            // One strong reference is converted into a raw pointer and handed to the
            // dispatcher. The handler takes it back with `CallbackState::from_raw_ptr`.
            let callback_state_ptr = CallbackState::make_raw_ptr(self.callback_state.clone());
            Status::ok(unsafe { async_begin_wait(dispatcher.inner().as_ptr(), callback_state_ptr) })
                .inspect_err(|_| {
                    // The wait was not started, so the handler will not run. Release the
                    // reference created by `make_raw_ptr` here instead.
                    unsafe { CallbackState::release_raw_ptr(callback_state_ptr) };
                })
        })
    }
}
291
/// Receive-side state shared between [`Exclusive`] and a wait pending on the dispatcher.
struct RecvCallbackState {
    /// Reference to the channel whose raw handle the `async_wait` refers to.
    _channel: Arc<Channel>,
    /// Set by the handler when the dispatcher reports `ZX_ERR_CANCELED`. Checked by
    /// `Exclusive::wait_readable`.
    canceled: AtomicBool,
    /// Waker of the task waiting to receive. Woken by the handler and by
    /// `RecvFutureState::drop`.
    waker: AtomicWaker,
}
298
impl RecvCallbackState {
    /// `async_wait` handler run by the dispatcher when the channel becomes readable, its peer
    /// closes, or the wait is canceled.
    ///
    /// # Safety
    ///
    /// `callback_state_ptr` must be a pointer produced by `CallbackState::make_raw_ptr` whose
    /// reference has not already been reclaimed. This function takes that reference back.
    unsafe extern "C" fn handler(
        _dispatcher: *mut async_dispatcher,
        callback_state_ptr: *mut async_wait,
        status: zx_status_t,
        _packet: *const zx_packet_signal_t,
    ) {
        debug_assert!(
            status == ZX_OK || status == ZX_ERR_CANCELED,
            "task callback called with status other than ok or canceled"
        );
        // Reclaims the strong reference handed out in `Exclusive::wait_readable`. It is
        // dropped when this function returns, after which `wait_readable` will arm a new wait.
        let state = unsafe { CallbackState::from_raw_ptr(callback_state_ptr) };
        // Remember cancellation so subsequent `wait_readable` calls fail with CANCELED instead
        // of re-arming.
        if status == ZX_ERR_CANCELED {
            state.canceled.store(true, Ordering::Relaxed);
        }
        // Wake the receiver in every case. Its next poll re-reads the channel (or sees the
        // canceled flag) to learn what happened.
        state.waker.wake();
    }
}
319
/// State of an in-progress receive on an [`AsyncChannel`].
pub struct RecvFutureState {
    /// Buffer the next message is read into. Taken (leaving `None`) when the message is
    /// returned.
    buffer: Option<Buffer>,
    /// Weak reference to the exclusive half's wait state, upgraded on drop to wake its waker.
    callback_state: Weak<CallbackState>,
}
325
326impl Drop for RecvFutureState {
327 fn drop(&mut self) {
328 let Some(state) = self.callback_state.upgrade() else { return };
329 state.waker.wake();
331 }
332}
333
/// State of an in-progress send on an [`AsyncChannel`].
pub struct SendFutureState {
    /// Message to write: its bytes (`chunks`) and handles (`handles`).
    buffer: Buffer,
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use fdf::CurrentDispatcher;
    use fdf_env::test::spawn_in_driver;
    use fidl_next::{ClientDispatcher, ClientEnd, IgnoreEvents};
    use fidl_next_fuchsia_examples_gizmo::Device;

    // Starts a FIDL client on the driver dispatcher whose receive has nothing to read, so its
    // wait is still pending when the driver future finishes. Going by the test name, this
    // covers dispatcher shutdown with an outstanding wait (the handler's ZX_ERR_CANCELED path).
    // NOTE(review): confirm that `spawn_in_driver` shuts the dispatcher down afterwards.
    #[fuchsia::test]
    async fn wait_pending_at_dispatcher_shutdown() {
        spawn_in_driver("driver fidl server", async {
            // Nothing is ever written on the server end, so the client has nothing to receive.
            let (_server_chan, client_chan) = Channel::create();
            let client_end: ClientEnd<Device, _> = ClientEnd::<Device, _>::from_untyped(
                AsyncChannel::new_on_dispatcher(CurrentDispatcher, client_chan),
            );
            let client_dispatcher = ClientDispatcher::new(client_end);
            let _client = client_dispatcher.client();
            CurrentDispatcher
                .spawn(async {
                    println!(
                        "client task finished: {:?}",
                        client_dispatcher.run(IgnoreEvents).await.map(|_| ())
                    );
                })
                .unwrap();
            // Returned from the future rather than dropped at the end of this block.
            (_server_chan, _client)
        });
    }
}