1#![deny(unsafe_op_in_unsafe_fn, missing_docs)]
7
8use std::mem::replace;
9use std::pin::Pin;
10use std::ptr::NonNull;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Weak};
13use std::task::{Context, Poll};
14
15use fidl_next::decoder::InternalHandleDecoder;
16use fidl_next::encoder::InternalHandleEncoder;
17use fidl_next::fuchsia::{HandleDecoder, HandleEncoder};
18use fidl_next::protocol::NonBlockingTransport;
19use fidl_next::{
20 CHUNK_SIZE, Chunk, ClientEnd, DecodeError, Decoder, EncodeError, Encoder, Executor,
21 HasExecutor, ServerEnd, Transport,
22};
23use futures::task::AtomicWaker;
24use libasync::callback_state::CallbackSharedState;
25use libasync::{JoinHandle, OnDispatcher};
26use libasync_sys::{async_begin_wait, async_dispatcher, async_wait};
27use zx::sys::{
28 ZX_CHANNEL_PEER_CLOSED, ZX_CHANNEL_READABLE, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_CANCELED,
29 ZX_ERR_PEER_CLOSED, ZX_ERR_SHOULD_WAIT, ZX_OK, zx_channel_read, zx_channel_write, zx_handle_t,
30 zx_packet_signal_t, zx_status_t,
31};
32use zx::{AsHandleRef, Channel, NullableHandle, Status};
33
34#[derive(Debug, PartialEq)]
36pub struct AsyncChannel<D> {
37 dispatcher: D,
38 channel: Arc<Channel>,
39}
40
41impl<D> AsyncChannel<D> {
42 pub fn new_on_dispatcher(dispatcher: D, channel: Channel) -> Self {
44 Self { dispatcher, channel: Arc::new(channel) }
45 }
46
47 pub fn client_from_zx_channel_on_dispatcher<P>(
50 from: ClientEnd<P, Channel>,
51 dispatcher: D,
52 ) -> ClientEnd<P, Self> {
53 let channel = from.into_untyped();
54 ClientEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
55 }
56
57 pub fn server_from_zx_channel_on_dispatcher<P>(
60 from: ServerEnd<P, Channel>,
61 dispatcher: D,
62 ) -> ServerEnd<P, Self> {
63 let channel = from.into_untyped();
64 ServerEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
65 }
66}
67
68impl<D: Default> AsyncChannel<D> {
69 pub fn new(channel: Channel) -> Self {
72 Self::new_on_dispatcher(D::default(), channel)
73 }
74
75 pub fn client_from_zx_channel<P>(from: ClientEnd<P, Channel>) -> ClientEnd<P, Self> {
78 Self::client_from_zx_channel_on_dispatcher(from, D::default())
79 }
80
81 pub fn server_from_zx_channel<P>(from: ServerEnd<P, Channel>) -> ServerEnd<P, Self> {
84 Self::server_from_zx_channel_on_dispatcher(from, D::default())
85 }
86}
87
88impl<D: OnDispatcher> Transport for AsyncChannel<D> {
89 type Error = Status;
90 type Shared = Arc<Channel>;
91 type Exclusive = Exclusive<D>;
92 type SendBuffer = Buffer;
93 type SendFutureState = SendFutureState;
94 type RecvFutureState = RecvFutureState;
95 type RecvBuffer = RecvBuffer;
96
97 fn split(self) -> (Self::Shared, Self::Exclusive) {
98 let channel = self.channel;
99 let object = channel.raw_handle();
100 (
101 channel.clone(),
102 Exclusive {
103 dispatcher: self.dispatcher,
104 callback_state: CallbackState::new(
105 async_wait {
106 handler: Some(RecvCallbackState::handler),
107 object,
108 trigger: ZX_CHANNEL_PEER_CLOSED | ZX_CHANNEL_READABLE,
109 ..Default::default()
110 },
111 RecvCallbackState {
112 _channel: channel,
113 canceled: AtomicBool::new(false),
114 waker: AtomicWaker::new(),
115 },
116 ),
117 },
118 )
119 }
120
121 fn acquire(_shared: &Self::Shared) -> Self::SendBuffer {
122 Buffer::new()
123 }
124
125 fn begin_send(_: &Self::Shared, buffer: Self::SendBuffer) -> Self::SendFutureState {
126 SendFutureState { buffer }
127 }
128
129 fn poll_send(
130 future_state: Pin<&mut Self::SendFutureState>,
131 _: &mut Context<'_>,
132 shared: &Self::Shared,
133 ) -> Poll<Result<(), Option<Self::Error>>> {
134 Poll::Ready(Self::send_immediately(future_state.get_mut(), shared))
135 }
136
137 fn begin_recv(
138 _shared: &Self::Shared,
139 exclusive: &mut Self::Exclusive,
140 ) -> Self::RecvFutureState {
141 RecvFutureState {
142 buffer: Some(Buffer::new()),
143 callback_state: Arc::downgrade(&exclusive.callback_state),
144 }
145 }
146
147 fn poll_recv(
148 mut future_state: Pin<&mut Self::RecvFutureState>,
149 cx: &mut Context<'_>,
150 shared: &Self::Shared,
151 exclusive: &mut Self::Exclusive,
152 ) -> Poll<Result<Self::RecvBuffer, Option<Self::Error>>> {
153 let buffer = future_state.buffer.as_mut().unwrap();
154
155 let mut actual_bytes = 0;
156 let mut actual_handles = 0;
157
158 loop {
159 let result = unsafe {
160 zx_channel_read(
161 shared.raw_handle(),
162 0,
163 buffer.chunks.as_mut_ptr().cast(),
164 buffer.handles.as_mut_ptr().cast(),
165 (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
166 buffer.handles.capacity() as u32,
167 &mut actual_bytes,
168 &mut actual_handles,
169 )
170 };
171
172 match result {
173 ZX_OK => {
174 unsafe {
175 buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
176 buffer.handles.set_len(actual_handles as usize);
177 }
178 return Poll::Ready(Ok(RecvBuffer {
179 buffer: future_state.buffer.take().unwrap(),
180 chunks_taken: 0,
181 handles_taken: 0,
182 }));
183 }
184 ZX_ERR_PEER_CLOSED => return Poll::Ready(Err(None)),
185 ZX_ERR_BUFFER_TOO_SMALL => {
186 let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
187 buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
188 buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
189 }
190 ZX_ERR_SHOULD_WAIT => {
191 exclusive.wait_readable(cx)?;
192 return Poll::Pending;
193 }
194 raw => return Poll::Ready(Err(Some(Status::from_raw(raw)))),
195 }
196 }
197 }
198}
199
200impl<D: OnDispatcher> NonBlockingTransport for AsyncChannel<D> {
201 fn send_immediately(
202 future_state: &mut Self::SendFutureState,
203 shared: &Self::Shared,
204 ) -> Result<(), Option<Self::Error>> {
205 let result = unsafe {
206 zx_channel_write(
207 shared.raw_handle(),
208 0,
209 future_state.buffer.chunks.as_ptr().cast::<u8>(),
210 (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
211 future_state.buffer.handles.as_ptr().cast(),
212 future_state.buffer.handles.len() as u32,
213 )
214 };
215
216 match result {
217 ZX_OK => {
218 unsafe {
220 future_state.buffer.handles.set_len(0);
221 }
222 Ok(())
223 }
224 ZX_ERR_PEER_CLOSED => Err(None),
225 _ => Err(Some(Status::from_raw(result))),
226 }
227 }
228}
229
230pub struct FidlExecutor<D>(D);
233
234impl<D> std::ops::Deref for FidlExecutor<D> {
235 type Target = D;
236 fn deref(&self) -> &Self::Target {
237 &self.0
238 }
239}
240
241impl<D> From<D> for FidlExecutor<D> {
242 fn from(value: D) -> Self {
243 FidlExecutor(value)
244 }
245}
246
247impl<D: OnDispatcher + 'static> Executor for FidlExecutor<D> {
248 type JoinHandle<T: 'static> = JoinHandle<T>;
249
250 fn spawn<F>(&self, future: F) -> Self::JoinHandle<F::Output>
251 where
252 F: Future + Send + 'static,
253 F::Output: Send + 'static,
254 {
255 self.0.compute(future).detach_on_drop()
256 }
257}
258
259impl<D: OnDispatcher> fidl_next::RunsTransport<AsyncChannel<D>> for FidlExecutor<D> {}
260
261impl<D: OnDispatcher + 'static> HasExecutor for AsyncChannel<D> {
262 type Executor = FidlExecutor<D>;
263
264 fn executor(&self) -> Self::Executor {
265 FidlExecutor(self.dispatcher.clone())
266 }
267}
268
269type CallbackState = CallbackSharedState<async_wait, RecvCallbackState>;
270
271#[doc(hidden)] pub struct Exclusive<D> {
273 callback_state: Arc<CallbackState>,
274 dispatcher: D,
275}
276
277impl<D: OnDispatcher> Exclusive<D> {
278 fn wait_readable(&mut self, cx: &Context<'_>) -> Result<(), Status> {
279 self.callback_state.waker.register(cx.waker());
280 if self.callback_state.canceled.load(Ordering::Relaxed) {
281 return Err(Status::CANCELED);
283 }
284
285 if Arc::strong_count(&self.callback_state) > 1 {
286 return Ok(());
289 }
290 self.dispatcher.on_maybe_dispatcher(|dispatcher| {
291 let callback_state_ptr = CallbackState::make_raw_ptr(self.callback_state.clone());
292 Status::ok(unsafe { async_begin_wait(dispatcher.inner().as_ptr(), callback_state_ptr) })
294 .inspect_err(|_| {
295 unsafe { CallbackState::release_raw_ptr(callback_state_ptr) };
298 })
299 })
300 }
301}
302
303struct RecvCallbackState {
305 _channel: Arc<Channel>,
306 canceled: AtomicBool,
307 waker: AtomicWaker,
308}
309
310impl RecvCallbackState {
311 unsafe extern "C" fn handler(
312 _dispatcher: *mut async_dispatcher,
313 callback_state_ptr: *mut async_wait,
314 status: zx_status_t,
315 _packet: *const zx_packet_signal_t,
316 ) {
317 debug_assert!(
318 status == ZX_OK || status == ZX_ERR_CANCELED,
319 "task callback called with status other than ok or canceled"
320 );
321 let state = unsafe { CallbackState::from_raw_ptr(callback_state_ptr) };
324 if status == ZX_ERR_CANCELED {
325 state.canceled.store(true, Ordering::Relaxed);
326 }
327 state.waker.wake();
328 }
329}
330
331pub struct RecvFutureState {
333 buffer: Option<Buffer>,
334 callback_state: Weak<CallbackState>,
335}
336
337impl Drop for RecvFutureState {
338 fn drop(&mut self) {
339 let Some(state) = self.callback_state.upgrade() else { return };
340 state.waker.wake();
342 }
343}
344
345pub struct SendFutureState {
347 buffer: Buffer,
348}
349
350#[derive(Default)]
352pub struct Buffer {
353 handles: Vec<NullableHandle>,
354 chunks: Vec<Chunk>,
355}
356
357impl Buffer {
358 pub fn new() -> Self {
360 Self::default()
361 }
362
363 pub fn handles(&self) -> &[NullableHandle] {
365 &self.handles
366 }
367
368 pub fn bytes(&self) -> Vec<u8> {
370 self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
371 }
372
373 pub fn from_raw(handles: Vec<NullableHandle>, chunks: Vec<Chunk>) -> Self {
375 Self { handles, chunks }
376 }
377
378 pub fn from_raw_bytes(handles: Vec<NullableHandle>, bytes: impl AsRef<[u8]>) -> Self {
380 let bytes = bytes.as_ref();
381 assert!(bytes.len() % CHUNK_SIZE == 0);
382 let chunks = bytes
383 .chunks_exact(CHUNK_SIZE)
384 .map(|c| fidl_next::WireU64(u64::from_le_bytes(c.try_into().unwrap())))
385 .collect();
386 Self::from_raw(handles, chunks)
387 }
388}
389
390impl InternalHandleEncoder for Buffer {
391 #[inline]
392 fn __internal_handle_count(&self) -> usize {
393 self.handles.len()
394 }
395}
396
397impl Encoder for Buffer {
398 #[inline]
399 fn bytes_written(&self) -> usize {
400 Encoder::bytes_written(&self.chunks)
401 }
402
403 #[inline]
404 fn write_zeroes(&mut self, len: usize) {
405 Encoder::write_zeroes(&mut self.chunks, len)
406 }
407
408 #[inline]
409 fn write(&mut self, bytes: &[u8]) {
410 Encoder::write(&mut self.chunks, bytes)
411 }
412
413 #[inline]
414 fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
415 Encoder::rewrite(&mut self.chunks, pos, bytes)
416 }
417}
418
419impl HandleEncoder for Buffer {
420 fn push_handle(&mut self, handle: NullableHandle) -> Result<(), EncodeError> {
421 self.handles.push(handle);
422 Ok(())
423 }
424
425 fn handles_pushed(&self) -> usize {
426 self.handles.len()
427 }
428}
429
430pub struct RecvBuffer {
432 buffer: Buffer,
433 chunks_taken: usize,
434 handles_taken: usize,
435}
436
437impl RecvBuffer {
438 pub fn new(buffer: Buffer) -> Self {
440 Self { buffer, chunks_taken: 0, handles_taken: 0 }
441 }
442}
443
444unsafe impl Decoder for RecvBuffer {
445 fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
446 if count > self.buffer.chunks.len() - self.chunks_taken {
447 return Err(DecodeError::InsufficientData);
448 }
449
450 let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
451 self.chunks_taken += count;
452
453 unsafe { Ok(NonNull::new_unchecked(chunks)) }
454 }
455
456 fn commit(&mut self) {
457 for handle in &mut self.buffer.handles[0..self.handles_taken] {
458 let _ = replace(handle, NullableHandle::invalid()).into_raw();
460 }
461 }
462
463 fn finish(&self) -> Result<(), DecodeError> {
464 if self.chunks_taken != self.buffer.chunks.len() {
465 return Err(DecodeError::ExtraBytes {
466 num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
467 });
468 }
469
470 if self.handles_taken != self.buffer.handles.len() {
471 return Err(DecodeError::ExtraHandles {
472 num_extra: self.buffer.handles.len() - self.handles_taken,
473 });
474 }
475
476 Ok(())
477 }
478}
479
480impl InternalHandleDecoder for RecvBuffer {
481 fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
482 if count > self.buffer.handles.len() - self.handles_taken {
483 return Err(DecodeError::InsufficientHandles);
484 }
485
486 for i in self.handles_taken..self.handles_taken + count {
487 let handle = replace(&mut self.buffer.handles[i], NullableHandle::invalid());
488 drop(handle);
489 }
490 self.handles_taken += count;
491
492 Ok(())
493 }
494
495 fn __internal_handles_remaining(&self) -> usize {
496 self.buffer.handles.len() - self.handles_taken
497 }
498}
499
500impl HandleDecoder for RecvBuffer {
501 fn take_raw_handle(&mut self) -> Result<zx_handle_t, DecodeError> {
502 if self.handles_taken >= self.buffer.handles.len() {
503 return Err(DecodeError::InsufficientHandles);
504 }
505
506 let handle = self.buffer.handles[self.handles_taken].raw_handle();
507 self.handles_taken += 1;
508
509 Ok(handle)
510 }
511
512 fn handles_remaining(&mut self) -> usize {
513 self.buffer.handles.len() - self.handles_taken
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use fdf::CurrentDispatcher;
    use fdf_env::test::spawn_in_driver;
    use fidl_next::{ClientDispatcher, ClientEnd, IgnoreEvents};
    use fidl_next_fuchsia_examples_gizmo::Device;

    /// Shuts the driver dispatcher down while a client over [`AsyncChannel`] is (presumably)
    /// still waiting for a message from a silent server end.
    #[fuchsia::test]
    async fn wait_pending_at_dispatcher_shutdown() {
        spawn_in_driver("driver fidl server", async {
            let (server_chan, client_chan) = Channel::create();
            let transport = AsyncChannel::new_on_dispatcher(CurrentDispatcher, client_chan);
            let client_end = ClientEnd::<Device, _>::from_untyped(transport);
            let client_dispatcher = ClientDispatcher::new(client_end);
            let client = client_dispatcher.client();
            CurrentDispatcher
                .spawn(async {
                    let outcome = client_dispatcher.run(IgnoreEvents).await.map(|_| ());
                    println!("client task finished: {:?}", outcome);
                })
                .unwrap();
            // Hand both ends back so they stay open past this block. The server end never
            // writes, so the client's read wait should still be armed at shutdown.
            (server_chan, client)
        });
    }
}