libasync_fidl/
lib.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Safe bindings for using FIDL with the libasync C API
6#![deny(unsafe_op_in_unsafe_fn, missing_docs)]
7
8use std::mem::replace;
9use std::pin::Pin;
10use std::ptr::NonNull;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Weak};
13use std::task::{Context, Poll};
14
15use fidl_next::decoder::InternalHandleDecoder;
16use fidl_next::encoder::InternalHandleEncoder;
17use fidl_next::fuchsia::{HandleDecoder, HandleEncoder};
18use fidl_next::protocol::NonBlockingTransport;
19use fidl_next::{
20    CHUNK_SIZE, Chunk, ClientEnd, DecodeError, Decoder, EncodeError, Encoder, Executor,
21    HasExecutor, ServerEnd, Transport,
22};
23use futures::task::AtomicWaker;
24use libasync::callback_state::CallbackSharedState;
25use libasync::{OnDispatcher, Task};
26use libasync_sys::{async_begin_wait, async_dispatcher, async_wait};
27use zx::sys::{
28    ZX_CHANNEL_PEER_CLOSED, ZX_CHANNEL_READABLE, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_CANCELED,
29    ZX_ERR_PEER_CLOSED, ZX_ERR_SHOULD_WAIT, ZX_OK, zx_channel_read, zx_channel_write, zx_handle_t,
30    zx_packet_signal_t, zx_status_t,
31};
32use zx::{AsHandleRef, Channel, HandleBased, NullableHandle, Status};
33
34/// A fidl-compatible channel that uses a [`libasync`] dispatcher.
35#[derive(Debug, PartialEq)]
36pub struct AsyncChannel<D> {
37    dispatcher: D,
38    channel: Arc<Channel>,
39}
40
41impl<D> AsyncChannel<D> {
42    /// Creates an async channel bound to the dispatcher `d` that can be used with fidl bindings.
43    pub fn new_on_dispatcher(dispatcher: D, channel: Channel) -> Self {
44        Self { dispatcher, channel: Arc::new(channel) }
45    }
46
47    /// A shortcut for creating a [`fidl_next`] compatible [`ClientEnd`] out of a
48    /// [`Channel`] and dispatcher.
49    pub fn client_from_zx_channel_on_dispatcher<P>(
50        from: ClientEnd<P, Channel>,
51        dispatcher: D,
52    ) -> ClientEnd<P, Self> {
53        let channel = from.into_untyped();
54        ClientEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
55    }
56
57    /// A shortcut for creating a [`fidl_next`] compatible [`ServerEnd`] out of a
58    /// [`Channel`] and dispatcher.
59    pub fn server_from_zx_channel_on_dispatcher<P>(
60        from: ServerEnd<P, Channel>,
61        dispatcher: D,
62    ) -> ServerEnd<P, Self> {
63        let channel = from.into_untyped();
64        ServerEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
65    }
66}
67
68impl<D: Default> AsyncChannel<D> {
69    /// Creates an async channel bound to the [`Default`] instance of dispatcher `D` that can
70    /// be used with fidl bindings.
71    pub fn new(channel: Channel) -> Self {
72        Self::new_on_dispatcher(D::default(), channel)
73    }
74
75    /// A shortcut for creating a [`fidl_next`] compatible [`ClientEnd`] out of a
76    /// [`Channel`].
77    pub fn client_from_zx_channel<P>(from: ClientEnd<P, Channel>) -> ClientEnd<P, Self> {
78        Self::client_from_zx_channel_on_dispatcher(from, D::default())
79    }
80
81    /// A shortcut for creating a [`fidl_next`] compatible [`ServerEnd`] out of a
82    /// [`Channel`].
83    pub fn server_from_zx_channel<P>(from: ServerEnd<P, Channel>) -> ServerEnd<P, Self> {
84        Self::server_from_zx_channel_on_dispatcher(from, D::default())
85    }
86}
87
88impl<D: OnDispatcher> Transport for AsyncChannel<D> {
89    type Error = Status;
90    type Shared = Arc<Channel>;
91    type Exclusive = Exclusive<D>;
92    type SendBuffer = Buffer;
93    type SendFutureState = SendFutureState;
94    type RecvFutureState = RecvFutureState;
95    type RecvBuffer = RecvBuffer;
96
97    fn split(self) -> (Self::Shared, Self::Exclusive) {
98        let channel = self.channel;
99        let object = channel.raw_handle();
100        (
101            channel.clone(),
102            Exclusive {
103                dispatcher: self.dispatcher,
104                callback_state: CallbackState::new(
105                    async_wait {
106                        handler: Some(RecvCallbackState::handler),
107                        object,
108                        trigger: ZX_CHANNEL_PEER_CLOSED | ZX_CHANNEL_READABLE,
109                        ..Default::default()
110                    },
111                    RecvCallbackState {
112                        _channel: channel,
113                        canceled: AtomicBool::new(false),
114                        waker: AtomicWaker::new(),
115                    },
116                ),
117            },
118        )
119    }
120
121    fn acquire(_shared: &Self::Shared) -> Self::SendBuffer {
122        Buffer::new()
123    }
124
125    fn begin_send(_: &Self::Shared, buffer: Self::SendBuffer) -> Self::SendFutureState {
126        SendFutureState { buffer }
127    }
128
129    fn poll_send(
130        future_state: Pin<&mut Self::SendFutureState>,
131        _: &mut Context<'_>,
132        shared: &Self::Shared,
133    ) -> Poll<Result<(), Option<Self::Error>>> {
134        Poll::Ready(Self::send_immediately(future_state.get_mut(), shared))
135    }
136
137    fn begin_recv(
138        _shared: &Self::Shared,
139        exclusive: &mut Self::Exclusive,
140    ) -> Self::RecvFutureState {
141        RecvFutureState {
142            buffer: Some(Buffer::new()),
143            callback_state: Arc::downgrade(&exclusive.callback_state),
144        }
145    }
146
147    fn poll_recv(
148        mut future_state: Pin<&mut Self::RecvFutureState>,
149        cx: &mut Context<'_>,
150        shared: &Self::Shared,
151        exclusive: &mut Self::Exclusive,
152    ) -> Poll<Result<Self::RecvBuffer, Option<Self::Error>>> {
153        let buffer = future_state.buffer.as_mut().unwrap();
154
155        let mut actual_bytes = 0;
156        let mut actual_handles = 0;
157
158        loop {
159            let result = unsafe {
160                zx_channel_read(
161                    shared.raw_handle(),
162                    0,
163                    buffer.chunks.as_mut_ptr().cast(),
164                    buffer.handles.as_mut_ptr().cast(),
165                    (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
166                    buffer.handles.capacity() as u32,
167                    &mut actual_bytes,
168                    &mut actual_handles,
169                )
170            };
171
172            match result {
173                ZX_OK => {
174                    unsafe {
175                        buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
176                        buffer.handles.set_len(actual_handles as usize);
177                    }
178                    return Poll::Ready(Ok(RecvBuffer {
179                        buffer: future_state.buffer.take().unwrap(),
180                        chunks_taken: 0,
181                        handles_taken: 0,
182                    }));
183                }
184                ZX_ERR_PEER_CLOSED => return Poll::Ready(Err(None)),
185                ZX_ERR_BUFFER_TOO_SMALL => {
186                    let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
187                    buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
188                    buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
189                }
190                ZX_ERR_SHOULD_WAIT => {
191                    exclusive.wait_readable(cx)?;
192                    return Poll::Pending;
193                }
194                raw => return Poll::Ready(Err(Some(Status::from_raw(raw)))),
195            }
196        }
197    }
198}
199
200impl<D: OnDispatcher> NonBlockingTransport for AsyncChannel<D> {
201    fn send_immediately(
202        future_state: &mut Self::SendFutureState,
203        shared: &Self::Shared,
204    ) -> Result<(), Option<Self::Error>> {
205        let result = unsafe {
206            zx_channel_write(
207                shared.raw_handle(),
208                0,
209                future_state.buffer.chunks.as_ptr().cast::<u8>(),
210                (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
211                future_state.buffer.handles.as_ptr().cast(),
212                future_state.buffer.handles.len() as u32,
213            )
214        };
215
216        match result {
217            ZX_OK => {
218                // Handles were written to the channel, so we must not drop them.
219                unsafe {
220                    future_state.buffer.handles.set_len(0);
221                }
222                Ok(())
223            }
224            ZX_ERR_PEER_CLOSED => Err(None),
225            _ => Err(Some(Status::from_raw(result))),
226        }
227    }
228}
229
/// A wrapper around a dispatcher reference object that can be used with the [`fidl_next`] bindings
/// to spawn client and server dispatchers on a driver runtime provided async dispatcher.
pub struct FidlExecutor<D>(D);

impl<D> std::ops::Deref for FidlExecutor<D> {
    type Target = D;

    /// Exposes the wrapped dispatcher.
    fn deref(&self) -> &D {
        let FidlExecutor(dispatcher) = self;
        dispatcher
    }
}

impl<D> From<D> for FidlExecutor<D> {
    /// Wraps `value` so it can be handed to [`fidl_next`] as an executor.
    fn from(value: D) -> Self {
        Self(value)
    }
}
246
247impl<D: OnDispatcher + 'static> Executor for FidlExecutor<D> {
248    type Task<T: 'static> = Task<T>;
249
250    fn spawn<F>(&self, future: F) -> Self::Task<F::Output>
251    where
252        F: Future + Send + 'static,
253        F::Output: Send + 'static,
254    {
255        self.0.compute(future)
256    }
257
258    fn detach<T: 'static>(&self, task: Self::Task<T>) {
259        task.detach()
260    }
261}
262
263impl<D: OnDispatcher> fidl_next::RunsTransport<AsyncChannel<D>> for FidlExecutor<D> {}
264
265impl<D: OnDispatcher + 'static> HasExecutor for AsyncChannel<D> {
266    type Executor = FidlExecutor<D>;
267
268    fn executor(&self) -> Self::Executor {
269        FidlExecutor(self.dispatcher.clone())
270    }
271}
272
273type CallbackState = CallbackSharedState<async_wait, RecvCallbackState>;
274
275#[doc(hidden)] // Internal implementation detail of fidl_next api
276pub struct Exclusive<D> {
277    callback_state: Arc<CallbackState>,
278    dispatcher: D,
279}
280
281impl<D: OnDispatcher> Exclusive<D> {
282    fn wait_readable(&mut self, cx: &Context<'_>) -> Result<(), Status> {
283        self.callback_state.waker.register(cx.waker());
284        if self.callback_state.canceled.load(Ordering::Relaxed) {
285            // the dispatcher has shut down so we can't wait again
286            return Err(Status::CANCELED);
287        }
288
289        if Arc::strong_count(&self.callback_state) > 1 {
290            // the callback is holding a strong reference to this so we're already waiting
291            // (or maybe in the process of cancelling) for a callback, so just return.
292            return Ok(());
293        }
294        self.dispatcher.on_maybe_dispatcher(|dispatcher| {
295            let callback_state_ptr = CallbackState::make_raw_ptr(self.callback_state.clone());
296            // SAFETY: fill this in
297            Status::ok(unsafe { async_begin_wait(dispatcher.inner().as_ptr(), callback_state_ptr) })
298                .inspect_err(|_| {
299                    // SAFETY: The wait failed so we have an outstanding reference to the callback
300                    // state that needs to be freed since the callback will not be called.
301                    unsafe { CallbackState::release_raw_ptr(callback_state_ptr) };
302                })
303        })
304    }
305}
306
307/// State shared between the callback and the future.
308struct RecvCallbackState {
309    _channel: Arc<Channel>,
310    canceled: AtomicBool,
311    waker: AtomicWaker,
312}
313
314impl RecvCallbackState {
315    unsafe extern "C" fn handler(
316        _dispatcher: *mut async_dispatcher,
317        callback_state_ptr: *mut async_wait,
318        status: zx_status_t,
319        _packet: *const zx_packet_signal_t,
320    ) {
321        debug_assert!(
322            status == ZX_OK || status == ZX_ERR_CANCELED,
323            "task callback called with status other than ok or canceled"
324        );
325        // SAFETY: This callback's copy of the `async_task` object was refcounted for when we
326        // started the wait.
327        let state = unsafe { CallbackState::from_raw_ptr(callback_state_ptr) };
328        if status == ZX_ERR_CANCELED {
329            state.canceled.store(true, Ordering::Relaxed);
330        }
331        state.waker.wake();
332    }
333}
334
335/// The state for a channel recv future.
336pub struct RecvFutureState {
337    buffer: Option<Buffer>,
338    callback_state: Weak<CallbackState>,
339}
340
341impl Drop for RecvFutureState {
342    fn drop(&mut self) {
343        let Some(state) = self.callback_state.upgrade() else { return };
344        // todo: properly implement cancelation
345        state.waker.wake();
346    }
347}
348
349/// The state for a channel send future.
350pub struct SendFutureState {
351    buffer: Buffer,
352}
353
354/// A channel buffer.
355#[derive(Default)]
356pub struct Buffer {
357    handles: Vec<NullableHandle>,
358    chunks: Vec<Chunk>,
359}
360
361impl Buffer {
362    /// New buffer.
363    pub fn new() -> Self {
364        Self::default()
365    }
366
367    /// Retrieve the handles.
368    pub fn handles(&self) -> &[NullableHandle] {
369        &self.handles
370    }
371
372    /// Retrieve the bytes.
373    pub fn bytes(&self) -> Vec<u8> {
374        self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
375    }
376
377    /// Make a buffer out of handles and chunks.
378    pub fn from_raw(handles: Vec<NullableHandle>, chunks: Vec<Chunk>) -> Self {
379        Self { handles, chunks }
380    }
381
382    /// Make a buffer out of handles and bytes. The bytes will be copied.
383    pub fn from_raw_bytes(handles: Vec<NullableHandle>, bytes: impl AsRef<[u8]>) -> Self {
384        let bytes = bytes.as_ref();
385        assert!(bytes.len() % CHUNK_SIZE == 0);
386        let chunks = bytes
387            .chunks_exact(CHUNK_SIZE)
388            .map(|c| fidl_next::WireU64(u64::from_le_bytes(c.try_into().unwrap())))
389            .collect();
390        Self::from_raw(handles, chunks)
391    }
392}
393
394impl InternalHandleEncoder for Buffer {
395    #[inline]
396    fn __internal_handle_count(&self) -> usize {
397        self.handles.len()
398    }
399}
400
401impl Encoder for Buffer {
402    #[inline]
403    fn bytes_written(&self) -> usize {
404        Encoder::bytes_written(&self.chunks)
405    }
406
407    #[inline]
408    fn write_zeroes(&mut self, len: usize) {
409        Encoder::write_zeroes(&mut self.chunks, len)
410    }
411
412    #[inline]
413    fn write(&mut self, bytes: &[u8]) {
414        Encoder::write(&mut self.chunks, bytes)
415    }
416
417    #[inline]
418    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
419        Encoder::rewrite(&mut self.chunks, pos, bytes)
420    }
421}
422
423impl HandleEncoder for Buffer {
424    fn push_handle(&mut self, handle: NullableHandle) -> Result<(), EncodeError> {
425        self.handles.push(handle);
426        Ok(())
427    }
428
429    fn handles_pushed(&self) -> usize {
430        self.handles.len()
431    }
432}
433
434/// A channel receive buffer.
435pub struct RecvBuffer {
436    buffer: Buffer,
437    chunks_taken: usize,
438    handles_taken: usize,
439}
440
441impl RecvBuffer {
442    /// Create a new receive buffer from a buffer.
443    pub fn new(buffer: Buffer) -> Self {
444        Self { buffer, chunks_taken: 0, handles_taken: 0 }
445    }
446}
447
448unsafe impl Decoder for RecvBuffer {
449    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
450        if count > self.buffer.chunks.len() - self.chunks_taken {
451            return Err(DecodeError::InsufficientData);
452        }
453
454        let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
455        self.chunks_taken += count;
456
457        unsafe { Ok(NonNull::new_unchecked(chunks)) }
458    }
459
460    fn commit(&mut self) {
461        for handle in &mut self.buffer.handles[0..self.handles_taken] {
462            // This handle was taken. To commit the current changes, we need to forget it.
463            let _ = replace(handle, NullableHandle::invalid()).into_raw();
464        }
465    }
466
467    fn finish(&self) -> Result<(), DecodeError> {
468        if self.chunks_taken != self.buffer.chunks.len() {
469            return Err(DecodeError::ExtraBytes {
470                num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
471            });
472        }
473
474        if self.handles_taken != self.buffer.handles.len() {
475            return Err(DecodeError::ExtraHandles {
476                num_extra: self.buffer.handles.len() - self.handles_taken,
477            });
478        }
479
480        Ok(())
481    }
482}
483
484impl InternalHandleDecoder for RecvBuffer {
485    fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
486        if count > self.buffer.handles.len() - self.handles_taken {
487            return Err(DecodeError::InsufficientHandles);
488        }
489
490        for i in self.handles_taken..self.handles_taken + count {
491            let handle = replace(&mut self.buffer.handles[i], NullableHandle::invalid());
492            drop(handle);
493        }
494        self.handles_taken += count;
495
496        Ok(())
497    }
498
499    fn __internal_handles_remaining(&self) -> usize {
500        self.buffer.handles.len() - self.handles_taken
501    }
502}
503
504impl HandleDecoder for RecvBuffer {
505    fn take_raw_handle(&mut self) -> Result<zx_handle_t, DecodeError> {
506        if self.handles_taken >= self.buffer.handles.len() {
507            return Err(DecodeError::InsufficientHandles);
508        }
509
510        let handle = self.buffer.handles[self.handles_taken].raw_handle();
511        self.handles_taken += 1;
512
513        Ok(handle)
514    }
515
516    fn handles_remaining(&mut self) -> usize {
517        self.buffer.handles.len() - self.handles_taken
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use fdf::CurrentDispatcher;
    use fdf_env::test::spawn_in_driver;
    use fidl_next::{ClientDispatcher, ClientEnd, IgnoreEvents};
    use fidl_next_fuchsia_examples_gizmo::Device;

    /// Starts a client whose receive loop is waiting on the channel. Per the test name, this
    /// checks that a wait still pending when the dispatcher shuts down is handled cleanly.
    #[fuchsia::test]
    async fn wait_pending_at_dispatcher_shutdown() {
        spawn_in_driver("driver fidl server", async {
            let (server_chan, client_chan) = Channel::create();
            let channel = AsyncChannel::new_on_dispatcher(CurrentDispatcher, client_chan);
            let client_end = ClientEnd::<Device, _>::from_untyped(channel);
            let client_dispatcher = ClientDispatcher::new(client_end);
            let client = client_dispatcher.client();
            let run_client = async {
                let outcome = client_dispatcher.run(IgnoreEvents).await.map(|_| ());
                println!("client task finished: {:?}", outcome);
            };
            CurrentDispatcher.spawn(run_client).unwrap();
            // Returned from the block so neither end is dropped when it finishes (presumably
            // keeping the wait pending until the dispatcher shuts down).
            (server_chan, client)
        });
    }
}