#![deny(unsafe_op_in_unsafe_fn, missing_docs)]

//! A [`fidl_next`] [`Transport`] implementation backed by Zircon channels,
//! reading and writing messages on a `libasync` dispatcher.

use std::mem::replace;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};

use fidl_next::decoder::InternalHandleDecoder;
use fidl_next::encoder::InternalHandleEncoder;
use fidl_next::fuchsia::{HandleDecoder, HandleEncoder};
use fidl_next::protocol::NonBlockingTransport;
use fidl_next::{
    CHUNK_SIZE, Chunk, ClientEnd, DecodeError, Decoder, EncodeError, Encoder, Executor,
    HasExecutor, ServerEnd, Transport,
};
use futures::task::AtomicWaker;
use libasync::callback_state::CallbackSharedState;
use libasync::{OnDispatcher, Task};
use libasync_sys::{async_begin_wait, async_dispatcher, async_wait};
use zx::sys::{
    ZX_CHANNEL_PEER_CLOSED, ZX_CHANNEL_READABLE, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_CANCELED,
    ZX_ERR_PEER_CLOSED, ZX_ERR_SHOULD_WAIT, ZX_OK, zx_channel_read, zx_channel_write, zx_handle_t,
    zx_packet_signal_t, zx_status_t,
};
use zx::{AsHandleRef, Channel, HandleBased, NullableHandle, Status};

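/// An asynchronous [`Transport`] over a Zircon [`Channel`], waited on via a
/// `libasync` dispatcher of type `D`.
///
/// A construction sketch (not compiled as a doc-test; assumes a `dispatcher`
/// value implementing [`OnDispatcher`] is in scope):
///
/// ```ignore
/// let (ours, theirs) = Channel::create();
/// let transport = AsyncChannel::new_on_dispatcher(dispatcher, ours);
/// // `theirs` is handed to the peer; `transport` implements `Transport`.
/// ```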
#[derive(Debug, PartialEq)]
pub struct AsyncChannel<D> {
    dispatcher: D,
    channel: Arc<Channel>,
}

impl<D> AsyncChannel<D> {
    /// Creates a new [`AsyncChannel`] that reads and writes `channel` on `dispatcher`.
    pub fn new_on_dispatcher(dispatcher: D, channel: Channel) -> Self {
        Self { dispatcher, channel: Arc::new(channel) }
    }

    /// Rebinds a [`ClientEnd`] over a raw Zircon [`Channel`] to one over an
    /// [`AsyncChannel`] that waits on `dispatcher`.
    pub fn client_from_zx_channel_on_dispatcher<P>(
        from: ClientEnd<P, Channel>,
        dispatcher: D,
    ) -> ClientEnd<P, Self> {
        let channel = from.into_untyped();
        ClientEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
    }

    /// Rebinds a [`ServerEnd`] over a raw Zircon [`Channel`] to one over an
    /// [`AsyncChannel`] that waits on `dispatcher`.
    pub fn server_from_zx_channel_on_dispatcher<P>(
        from: ServerEnd<P, Channel>,
        dispatcher: D,
    ) -> ServerEnd<P, Self> {
        let channel = from.into_untyped();
        ServerEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
    }
}

impl<D: Default> AsyncChannel<D> {
    /// Creates a new [`AsyncChannel`] for `channel` on the default dispatcher.
    pub fn new(channel: Channel) -> Self {
        Self::new_on_dispatcher(D::default(), channel)
    }

    /// Rebinds a [`ClientEnd`] over a raw Zircon [`Channel`] to one over an
    /// [`AsyncChannel`] on the default dispatcher.
    pub fn client_from_zx_channel<P>(from: ClientEnd<P, Channel>) -> ClientEnd<P, Self> {
        Self::client_from_zx_channel_on_dispatcher(from, D::default())
    }

    /// Rebinds a [`ServerEnd`] over a raw Zircon [`Channel`] to one over an
    /// [`AsyncChannel`] on the default dispatcher.
    pub fn server_from_zx_channel<P>(from: ServerEnd<P, Channel>) -> ServerEnd<P, Self> {
        Self::server_from_zx_channel_on_dispatcher(from, D::default())
    }
}

impl<D: OnDispatcher> Transport for AsyncChannel<D> {
    type Error = Status;
    type Shared = Arc<Channel>;
    type Exclusive = Exclusive<D>;
    type SendBuffer = Buffer;
    type SendFutureState = SendFutureState;
    type RecvFutureState = RecvFutureState;
    type RecvBuffer = RecvBuffer;

    fn split(self) -> (Self::Shared, Self::Exclusive) {
        let channel = self.channel;
        let object = channel.raw_handle();
        (
            channel.clone(),
            Exclusive {
                dispatcher: self.dispatcher,
                // Wake on an incoming message or on peer closure.
                callback_state: CallbackState::new(
                    async_wait {
                        handler: Some(RecvCallbackState::handler),
                        object,
                        trigger: ZX_CHANNEL_PEER_CLOSED | ZX_CHANNEL_READABLE,
                        ..Default::default()
                    },
                    RecvCallbackState {
                        _channel: channel,
                        canceled: AtomicBool::new(false),
                        waker: AtomicWaker::new(),
                    },
                ),
            },
        )
    }

    fn acquire(_shared: &Self::Shared) -> Self::SendBuffer {
        Buffer::new()
    }

    fn begin_send(_: &Self::Shared, buffer: Self::SendBuffer) -> Self::SendFutureState {
        SendFutureState { buffer }
    }

    fn poll_send(
        future_state: Pin<&mut Self::SendFutureState>,
        _: &mut Context<'_>,
        shared: &Self::Shared,
    ) -> Poll<Result<(), Option<Self::Error>>> {
        Poll::Ready(Self::send_immediately(future_state.get_mut(), shared))
    }

    fn begin_recv(
        _shared: &Self::Shared,
        exclusive: &mut Self::Exclusive,
    ) -> Self::RecvFutureState {
        RecvFutureState {
            buffer: Some(Buffer::new()),
            callback_state: Arc::downgrade(&exclusive.callback_state),
        }
    }

    fn poll_recv(
        mut future_state: Pin<&mut Self::RecvFutureState>,
        cx: &mut Context<'_>,
        shared: &Self::Shared,
        exclusive: &mut Self::Exclusive,
    ) -> Poll<Result<Self::RecvBuffer, Option<Self::Error>>> {
        let buffer = future_state.buffer.as_mut().unwrap();

        let mut actual_bytes = 0;
        let mut actual_handles = 0;

        loop {
            // SAFETY: the pointers and capacities describe the spare capacity
            // of `buffer`, which outlives the call.
            let result = unsafe {
                zx_channel_read(
                    shared.raw_handle(),
                    0,
                    buffer.chunks.as_mut_ptr().cast(),
                    buffer.handles.as_mut_ptr().cast(),
                    (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
                    buffer.handles.capacity() as u32,
                    &mut actual_bytes,
                    &mut actual_handles,
                )
            };

            match result {
                ZX_OK => {
                    // SAFETY: the kernel initialized exactly this many bytes
                    // and handles in the buffers.
                    unsafe {
                        buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
                        buffer.handles.set_len(actual_handles as usize);
                    }
                    return Poll::Ready(Ok(RecvBuffer {
                        buffer: future_state.buffer.take().unwrap(),
                        chunks_taken: 0,
                        handles_taken: 0,
                    }));
                }
                ZX_ERR_PEER_CLOSED => return Poll::Ready(Err(None)),
                ZX_ERR_BUFFER_TOO_SMALL => {
                    // Both vectors are still empty here, so reserving the full
                    // reported message size guarantees enough capacity for the
                    // retried read, and cannot underflow when only one of the
                    // two capacities was too small.
                    let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
                    buffer.chunks.reserve(min_chunks);
                    buffer.handles.reserve(actual_handles as usize);
                }
                ZX_ERR_SHOULD_WAIT => {
                    exclusive.wait_readable(cx)?;
                    return Poll::Pending;
                }
                raw => return Poll::Ready(Err(Some(Status::from_raw(raw)))),
            }
        }
    }
}

impl<D: OnDispatcher> NonBlockingTransport for AsyncChannel<D> {
    fn send_immediately(
        future_state: &mut Self::SendFutureState,
        shared: &Self::Shared,
    ) -> Result<(), Option<Self::Error>> {
        // SAFETY: the pointers and lengths describe the initialized chunks and
        // handles of the send buffer, which outlives the call.
        let result = unsafe {
            zx_channel_write(
                shared.raw_handle(),
                0,
                future_state.buffer.chunks.as_ptr().cast::<u8>(),
                (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
                future_state.buffer.handles.as_ptr().cast(),
                future_state.buffer.handles.len() as u32,
            )
        };

        match result {
            ZX_OK => {
                // SAFETY: the kernel took ownership of the handles on success,
                // so the buffer must not close them again.
                unsafe {
                    future_state.buffer.handles.set_len(0);
                }
                Ok(())
            }
            ZX_ERR_PEER_CLOSED => Err(None),
            _ => Err(Some(Status::from_raw(result))),
        }
    }
}

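/// An [`Executor`] that runs FIDL dispatch tasks on the wrapped `libasync`
/// dispatcher.
///
/// A minimal usage sketch (not compiled as a doc-test; assumes a `dispatcher`
/// value implementing [`OnDispatcher`] is in scope):
///
/// ```ignore
/// let executor = FidlExecutor::from(dispatcher);
/// let task = executor.spawn(async { 42 });
/// // Let the task run to completion in the background.
/// executor.detach(task);
/// ```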
pub struct FidlExecutor<D>(D);

impl<D> std::ops::Deref for FidlExecutor<D> {
    type Target = D;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<D> From<D> for FidlExecutor<D> {
    fn from(value: D) -> Self {
        FidlExecutor(value)
    }
}

impl<D: OnDispatcher + 'static> Executor for FidlExecutor<D> {
    type Task<T: 'static> = Task<T>;

    fn spawn<F>(&self, future: F) -> Self::Task<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.0.compute(future)
    }

    fn detach<T: 'static>(&self, task: Self::Task<T>) {
        task.detach()
    }
}

impl<D: OnDispatcher> fidl_next::RunsTransport<AsyncChannel<D>> for FidlExecutor<D> {}

impl<D: OnDispatcher + 'static> HasExecutor for AsyncChannel<D> {
    type Executor = FidlExecutor<D>;

    fn executor(&self) -> Self::Executor {
        FidlExecutor(self.dispatcher.clone())
    }
}

type CallbackState = CallbackSharedState<async_wait, RecvCallbackState>;

/// The exclusive, receive-side state of a split [`AsyncChannel`].
#[doc(hidden)]
pub struct Exclusive<D> {
    callback_state: Arc<CallbackState>,
    dispatcher: D,
}

impl<D: OnDispatcher> Exclusive<D> {
    fn wait_readable(&mut self, cx: &Context<'_>) -> Result<(), Status> {
        self.callback_state.waker.register(cx.waker());
        if self.callback_state.canceled.load(Ordering::Relaxed) {
            // The dispatcher canceled the wait, usually because it is shutting
            // down; the channel will never become readable through it again.
            return Err(Status::CANCELED);
        }

        if Arc::strong_count(&self.callback_state) > 1 {
            // The dispatcher still holds a reference, so a wait is already
            // registered and the callback will wake the registered waker.
            return Ok(());
        }
        self.dispatcher.on_maybe_dispatcher(|dispatcher| {
            let callback_state_ptr = CallbackState::make_raw_ptr(self.callback_state.clone());
            // SAFETY: `callback_state_ptr` is a valid `async_wait` that stays
            // alive until the handler runs or registration fails.
            Status::ok(unsafe { async_begin_wait(dispatcher.inner().as_ptr(), callback_state_ptr) })
                .inspect_err(|_| {
                    // Registration failed, so the handler will never run to
                    // release the reference; release it here instead.
                    // SAFETY: this pointer was just created by `make_raw_ptr`
                    // and was not taken over by the dispatcher.
                    unsafe { CallbackState::release_raw_ptr(callback_state_ptr) };
                })
        })
    }
}

/// State shared with the dispatcher's channel-readable wait callback.
struct RecvCallbackState {
    // Keeps the channel alive while the kernel may still signal the wait.
    _channel: Arc<Channel>,
    canceled: AtomicBool,
    waker: AtomicWaker,
}

impl RecvCallbackState {
    unsafe extern "C" fn handler(
        _dispatcher: *mut async_dispatcher,
        callback_state_ptr: *mut async_wait,
        status: zx_status_t,
        _packet: *const zx_packet_signal_t,
    ) {
        debug_assert!(
            status == ZX_OK || status == ZX_ERR_CANCELED,
            "task callback called with status other than ok or canceled"
        );
        // SAFETY: the dispatcher hands back the pointer created by
        // `make_raw_ptr` in `wait_readable`; this reclaims that reference.
        let state = unsafe { CallbackState::from_raw_ptr(callback_state_ptr) };
        if status == ZX_ERR_CANCELED {
            // The wait was canceled (typically at dispatcher shutdown), so the
            // next poll must fail instead of re-registering.
            state.canceled.store(true, Ordering::Relaxed);
        }
        state.waker.wake();
    }
}

/// The state of a pending receive on an [`AsyncChannel`].
pub struct RecvFutureState {
    buffer: Option<Buffer>,
    callback_state: Weak<CallbackState>,
}

impl Drop for RecvFutureState {
    fn drop(&mut self) {
        let Some(state) = self.callback_state.upgrade() else { return };
        // Wake the registered waker so that a replacement future gets polled
        // and can register its own interest.
        state.waker.wake();
    }
}

/// The state of a pending send on an [`AsyncChannel`].
pub struct SendFutureState {
    buffer: Buffer,
}

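/// An owned FIDL message body: 8-byte little-endian chunks plus handles, used
/// both for encoding outgoing messages and as storage for received ones.
///
/// A sketch of encoding through the [`Encoder`] impl below (not compiled as a
/// doc-test; only items from this module are used):
///
/// ```ignore
/// let mut buffer = Buffer::new();
/// Encoder::write(&mut buffer, &[0xAB; 8]);
/// assert_eq!(buffer.bytes(), vec![0xAB; 8]);
/// ```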
#[derive(Default)]
pub struct Buffer {
    handles: Vec<NullableHandle>,
    chunks: Vec<Chunk>,
}

impl Buffer {
    /// Creates a new, empty [`Buffer`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the handles held by this buffer.
    pub fn handles(&self) -> &[NullableHandle] {
        &self.handles
    }

    /// Returns a copy of the buffer's contents as little-endian bytes.
    pub fn bytes(&self) -> Vec<u8> {
        self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
    }

    /// Creates a [`Buffer`] from raw handles and chunks.
    pub fn from_raw(handles: Vec<NullableHandle>, chunks: Vec<Chunk>) -> Self {
        Self { handles, chunks }
    }

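    /// Creates a [`Buffer`] from raw handles and little-endian bytes.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len()` is not a multiple of `CHUNK_SIZE`.
    ///
    /// A round-trip sketch (not compiled as a doc-test; only items from this
    /// module are used):
    ///
    /// ```ignore
    /// let buffer = Buffer::from_raw_bytes(Vec::new(), [1u8; 2 * CHUNK_SIZE]);
    /// assert_eq!(buffer.bytes(), vec![1u8; 2 * CHUNK_SIZE]);
    /// assert!(buffer.handles().is_empty());
    /// ```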
    pub fn from_raw_bytes(handles: Vec<NullableHandle>, bytes: impl AsRef<[u8]>) -> Self {
        let bytes = bytes.as_ref();
        assert!(bytes.len() % CHUNK_SIZE == 0);
        let chunks = bytes
            .chunks_exact(CHUNK_SIZE)
            .map(|c| fidl_next::WireU64(u64::from_le_bytes(c.try_into().unwrap())))
            .collect();
        Self::from_raw(handles, chunks)
    }
}

impl InternalHandleEncoder for Buffer {
    #[inline]
    fn __internal_handle_count(&self) -> usize {
        self.handles.len()
    }
}

impl Encoder for Buffer {
    #[inline]
    fn bytes_written(&self) -> usize {
        Encoder::bytes_written(&self.chunks)
    }

    #[inline]
    fn write_zeroes(&mut self, len: usize) {
        Encoder::write_zeroes(&mut self.chunks, len)
    }

    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        Encoder::write(&mut self.chunks, bytes)
    }

    #[inline]
    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
        Encoder::rewrite(&mut self.chunks, pos, bytes)
    }
}

impl HandleEncoder for Buffer {
    fn push_handle(&mut self, handle: NullableHandle) -> Result<(), EncodeError> {
        self.handles.push(handle);
        Ok(())
    }

    fn handles_pushed(&self) -> usize {
        self.handles.len()
    }
}

/// A received FIDL message, tracking how many chunks and handles have been
/// taken during decoding.
pub struct RecvBuffer {
    buffer: Buffer,
    chunks_taken: usize,
    handles_taken: usize,
}

impl RecvBuffer {
    /// Creates a [`RecvBuffer`] over `buffer` with nothing taken yet.
    pub fn new(buffer: Buffer) -> Self {
        Self { buffer, chunks_taken: 0, handles_taken: 0 }
    }
}

// SAFETY: `take_chunks_raw` returns a pointer to `count` contiguous,
// initialized chunks that remain valid until the buffer is dropped.
unsafe impl Decoder for RecvBuffer {
    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
        if count > self.buffer.chunks.len() - self.chunks_taken {
            return Err(DecodeError::InsufficientData);
        }

        // SAFETY: `chunks_taken` is within the initialized length, so the
        // offset pointer is in bounds and non-null.
        let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
        self.chunks_taken += count;

        unsafe { Ok(NonNull::new_unchecked(chunks)) }
    }

    fn commit(&mut self) {
        for handle in &mut self.buffer.handles[0..self.handles_taken] {
            // Ownership of each taken handle moved into the decoded value, so
            // forget the raw handle rather than closing it on drop.
            let _ = replace(handle, NullableHandle::invalid()).into_raw();
        }
    }

    fn finish(&self) -> Result<(), DecodeError> {
        if self.chunks_taken != self.buffer.chunks.len() {
            return Err(DecodeError::ExtraBytes {
                num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
            });
        }

        if self.handles_taken != self.buffer.handles.len() {
            return Err(DecodeError::ExtraHandles {
                num_extra: self.buffer.handles.len() - self.handles_taken,
            });
        }

        Ok(())
    }
}

impl InternalHandleDecoder for RecvBuffer {
    fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
        if count > self.buffer.handles.len() - self.handles_taken {
            return Err(DecodeError::InsufficientHandles);
        }

        for i in self.handles_taken..self.handles_taken + count {
            let handle = replace(&mut self.buffer.handles[i], NullableHandle::invalid());
            drop(handle);
        }
        self.handles_taken += count;

        Ok(())
    }

    fn __internal_handles_remaining(&self) -> usize {
        self.buffer.handles.len() - self.handles_taken
    }
}

impl HandleDecoder for RecvBuffer {
    fn take_raw_handle(&mut self) -> Result<zx_handle_t, DecodeError> {
        if self.handles_taken >= self.buffer.handles.len() {
            return Err(DecodeError::InsufficientHandles);
        }

        let handle = self.buffer.handles[self.handles_taken].raw_handle();
        self.handles_taken += 1;

        Ok(handle)
    }

    fn handles_remaining(&mut self) -> usize {
        self.buffer.handles.len() - self.handles_taken
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use fdf::CurrentDispatcher;
    use fdf_env::test::spawn_in_driver;
    use fidl_next::{ClientDispatcher, ClientEnd, IgnoreEvents};
    use fidl_next_fuchsia_examples_gizmo::Device;

    // A recv left pending at dispatcher shutdown must not hang or leak: the
    // client task should finish once the driver is torn down.
    #[fuchsia::test]
    async fn wait_pending_at_dispatcher_shutdown() {
        spawn_in_driver("driver fidl server", async {
            let (_server_chan, client_chan) = Channel::create();
            let client_end: ClientEnd<Device, _> = ClientEnd::from_untyped(
                AsyncChannel::new_on_dispatcher(CurrentDispatcher, client_chan),
            );
            let client_dispatcher = ClientDispatcher::new(client_end);
            let _client = client_dispatcher.client();
            CurrentDispatcher
                .spawn(async {
                    println!(
                        "client task finished: {:?}",
                        client_dispatcher.run(IgnoreEvents).await.map(|_| ())
                    );
                })
                .unwrap();
            // Keep the server channel and client alive until the driver stops.
            (_server_chan, _client)
        });
    }
}