libasync_fidl/
lib.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Safe bindings for using FIDL with the libasync C API
6#![deny(unsafe_op_in_unsafe_fn, missing_docs)]
7
8use std::mem::replace;
9use std::pin::Pin;
10use std::ptr::NonNull;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Weak};
13use std::task::{Context, Poll};
14
15use fidl_next::decoder::InternalHandleDecoder;
16use fidl_next::encoder::InternalHandleEncoder;
17use fidl_next::fuchsia::{HandleDecoder, HandleEncoder};
18use fidl_next::protocol::NonBlockingTransport;
19use fidl_next::{
20    CHUNK_SIZE, Chunk, ClientEnd, DecodeError, Decoder, EncodeError, Encoder, Executor,
21    HasExecutor, ServerEnd, Transport,
22};
23use futures::task::AtomicWaker;
24use libasync::callback_state::CallbackSharedState;
25use libasync::{JoinHandle, OnDispatcher};
26use libasync_sys::{async_begin_wait, async_dispatcher, async_wait};
27use zx::sys::{
28    ZX_CHANNEL_PEER_CLOSED, ZX_CHANNEL_READABLE, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_CANCELED,
29    ZX_ERR_PEER_CLOSED, ZX_ERR_SHOULD_WAIT, ZX_OK, zx_channel_read, zx_channel_write, zx_handle_t,
30    zx_packet_signal_t, zx_status_t,
31};
32use zx::{AsHandleRef, Channel, NullableHandle, Status};
33
34/// A fidl-compatible channel that uses a [`libasync`] dispatcher.
35#[derive(Debug, PartialEq)]
36pub struct AsyncChannel<D> {
37    dispatcher: D,
38    channel: Arc<Channel>,
39}
40
41impl<D> AsyncChannel<D> {
42    /// Creates an async channel bound to the dispatcher `d` that can be used with fidl bindings.
43    pub fn new_on_dispatcher(dispatcher: D, channel: Channel) -> Self {
44        Self { dispatcher, channel: Arc::new(channel) }
45    }
46
47    /// A shortcut for creating a [`fidl_next`] compatible [`ClientEnd`] out of a
48    /// [`Channel`] and dispatcher.
49    pub fn client_from_zx_channel_on_dispatcher<P>(
50        from: ClientEnd<P, Channel>,
51        dispatcher: D,
52    ) -> ClientEnd<P, Self> {
53        let channel = from.into_untyped();
54        ClientEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
55    }
56
57    /// A shortcut for creating a [`fidl_next`] compatible [`ServerEnd`] out of a
58    /// [`Channel`] and dispatcher.
59    pub fn server_from_zx_channel_on_dispatcher<P>(
60        from: ServerEnd<P, Channel>,
61        dispatcher: D,
62    ) -> ServerEnd<P, Self> {
63        let channel = from.into_untyped();
64        ServerEnd::from_untyped(Self { dispatcher, channel: Arc::new(channel) })
65    }
66}
67
68impl<D: Default> AsyncChannel<D> {
69    /// Creates an async channel bound to the [`Default`] instance of dispatcher `D` that can
70    /// be used with fidl bindings.
71    pub fn new(channel: Channel) -> Self {
72        Self::new_on_dispatcher(D::default(), channel)
73    }
74
75    /// A shortcut for creating a [`fidl_next`] compatible [`ClientEnd`] out of a
76    /// [`Channel`].
77    pub fn client_from_zx_channel<P>(from: ClientEnd<P, Channel>) -> ClientEnd<P, Self> {
78        Self::client_from_zx_channel_on_dispatcher(from, D::default())
79    }
80
81    /// A shortcut for creating a [`fidl_next`] compatible [`ServerEnd`] out of a
82    /// [`Channel`].
83    pub fn server_from_zx_channel<P>(from: ServerEnd<P, Channel>) -> ServerEnd<P, Self> {
84        Self::server_from_zx_channel_on_dispatcher(from, D::default())
85    }
86}
87
88impl<D: OnDispatcher> Transport for AsyncChannel<D> {
89    type Error = Status;
90    type Shared = Arc<Channel>;
91    type Exclusive = Exclusive<D>;
92    type SendBuffer = Buffer;
93    type SendFutureState = SendFutureState;
94    type RecvFutureState = RecvFutureState;
95    type RecvBuffer = RecvBuffer;
96
97    fn split(self) -> (Self::Shared, Self::Exclusive) {
98        let channel = self.channel;
99        let object = channel.raw_handle();
100        (
101            channel.clone(),
102            Exclusive {
103                dispatcher: self.dispatcher,
104                callback_state: CallbackState::new(
105                    async_wait {
106                        handler: Some(RecvCallbackState::handler),
107                        object,
108                        trigger: ZX_CHANNEL_PEER_CLOSED | ZX_CHANNEL_READABLE,
109                        ..Default::default()
110                    },
111                    RecvCallbackState {
112                        _channel: channel,
113                        canceled: AtomicBool::new(false),
114                        waker: AtomicWaker::new(),
115                    },
116                ),
117            },
118        )
119    }
120
121    fn acquire(_shared: &Self::Shared) -> Self::SendBuffer {
122        Buffer::new()
123    }
124
125    fn begin_send(_: &Self::Shared, buffer: Self::SendBuffer) -> Self::SendFutureState {
126        SendFutureState { buffer }
127    }
128
129    fn poll_send(
130        future_state: Pin<&mut Self::SendFutureState>,
131        _: &mut Context<'_>,
132        shared: &Self::Shared,
133    ) -> Poll<Result<(), Option<Self::Error>>> {
134        Poll::Ready(Self::send_immediately(future_state.get_mut(), shared))
135    }
136
137    fn begin_recv(
138        _shared: &Self::Shared,
139        exclusive: &mut Self::Exclusive,
140    ) -> Self::RecvFutureState {
141        RecvFutureState {
142            buffer: Some(Buffer::new()),
143            callback_state: Arc::downgrade(&exclusive.callback_state),
144        }
145    }
146
147    fn poll_recv(
148        mut future_state: Pin<&mut Self::RecvFutureState>,
149        cx: &mut Context<'_>,
150        shared: &Self::Shared,
151        exclusive: &mut Self::Exclusive,
152    ) -> Poll<Result<Self::RecvBuffer, Option<Self::Error>>> {
153        let buffer = future_state.buffer.as_mut().unwrap();
154
155        let mut actual_bytes = 0;
156        let mut actual_handles = 0;
157
158        loop {
159            let result = unsafe {
160                zx_channel_read(
161                    shared.raw_handle(),
162                    0,
163                    buffer.chunks.as_mut_ptr().cast(),
164                    buffer.handles.as_mut_ptr().cast(),
165                    (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
166                    buffer.handles.capacity() as u32,
167                    &mut actual_bytes,
168                    &mut actual_handles,
169                )
170            };
171
172            match result {
173                ZX_OK => {
174                    unsafe {
175                        buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
176                        buffer.handles.set_len(actual_handles as usize);
177                    }
178                    return Poll::Ready(Ok(RecvBuffer {
179                        buffer: future_state.buffer.take().unwrap(),
180                        chunks_taken: 0,
181                        handles_taken: 0,
182                    }));
183                }
184                ZX_ERR_PEER_CLOSED => return Poll::Ready(Err(None)),
185                ZX_ERR_BUFFER_TOO_SMALL => {
186                    let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
187                    buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
188                    buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
189                }
190                ZX_ERR_SHOULD_WAIT => {
191                    exclusive.wait_readable(cx)?;
192                    return Poll::Pending;
193                }
194                raw => return Poll::Ready(Err(Some(Status::from_raw(raw)))),
195            }
196        }
197    }
198}
199
200impl<D: OnDispatcher> NonBlockingTransport for AsyncChannel<D> {
201    fn send_immediately(
202        future_state: &mut Self::SendFutureState,
203        shared: &Self::Shared,
204    ) -> Result<(), Option<Self::Error>> {
205        let result = unsafe {
206            zx_channel_write(
207                shared.raw_handle(),
208                0,
209                future_state.buffer.chunks.as_ptr().cast::<u8>(),
210                (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
211                future_state.buffer.handles.as_ptr().cast(),
212                future_state.buffer.handles.len() as u32,
213            )
214        };
215
216        match result {
217            ZX_OK => {
218                // Handles were written to the channel, so we must not drop them.
219                unsafe {
220                    future_state.buffer.handles.set_len(0);
221                }
222                Ok(())
223            }
224            ZX_ERR_PEER_CLOSED => Err(None),
225            _ => Err(Some(Status::from_raw(result))),
226        }
227    }
228}
229
/// Newtype around a dispatcher reference object that lets the [`fidl_next`] bindings
/// spawn client and server dispatchers on a driver runtime provided async dispatcher.
pub struct FidlExecutor<D>(D);

impl<D> std::ops::Deref for FidlExecutor<D> {
    type Target = D;

    /// Gives access to the wrapped dispatcher.
    fn deref(&self) -> &D {
        let Self(inner) = self;
        inner
    }
}

impl<D> From<D> for FidlExecutor<D> {
    /// Wraps `value` in a [`FidlExecutor`].
    fn from(value: D) -> Self {
        Self(value)
    }
}
246
247impl<D: OnDispatcher + 'static> Executor for FidlExecutor<D> {
248    type JoinHandle<T: 'static> = JoinHandle<T>;
249
250    fn spawn<F>(&self, future: F) -> Self::JoinHandle<F::Output>
251    where
252        F: Future + Send + 'static,
253        F::Output: Send + 'static,
254    {
255        self.0.compute(future).detach_on_drop()
256    }
257}
258
259impl<D: OnDispatcher> fidl_next::RunsTransport<AsyncChannel<D>> for FidlExecutor<D> {}
260
261impl<D: OnDispatcher + 'static> HasExecutor for AsyncChannel<D> {
262    type Executor = FidlExecutor<D>;
263
264    fn executor(&self) -> Self::Executor {
265        FidlExecutor(self.dispatcher.clone())
266    }
267}
268
269type CallbackState = CallbackSharedState<async_wait, RecvCallbackState>;
270
271#[doc(hidden)] // Internal implementation detail of fidl_next api
272pub struct Exclusive<D> {
273    callback_state: Arc<CallbackState>,
274    dispatcher: D,
275}
276
277impl<D: OnDispatcher> Exclusive<D> {
278    fn wait_readable(&mut self, cx: &Context<'_>) -> Result<(), Status> {
279        self.callback_state.waker.register(cx.waker());
280        if self.callback_state.canceled.load(Ordering::Relaxed) {
281            // the dispatcher has shut down so we can't wait again
282            return Err(Status::CANCELED);
283        }
284
285        if Arc::strong_count(&self.callback_state) > 1 {
286            // the callback is holding a strong reference to this so we're already waiting
287            // (or maybe in the process of cancelling) for a callback, so just return.
288            return Ok(());
289        }
290        self.dispatcher.on_maybe_dispatcher(|dispatcher| {
291            let callback_state_ptr = CallbackState::make_raw_ptr(self.callback_state.clone());
292            // SAFETY: fill this in
293            Status::ok(unsafe { async_begin_wait(dispatcher.inner().as_ptr(), callback_state_ptr) })
294                .inspect_err(|_| {
295                    // SAFETY: The wait failed so we have an outstanding reference to the callback
296                    // state that needs to be freed since the callback will not be called.
297                    unsafe { CallbackState::release_raw_ptr(callback_state_ptr) };
298                })
299        })
300    }
301}
302
303/// State shared between the callback and the future.
304struct RecvCallbackState {
305    _channel: Arc<Channel>,
306    canceled: AtomicBool,
307    waker: AtomicWaker,
308}
309
310impl RecvCallbackState {
311    unsafe extern "C" fn handler(
312        _dispatcher: *mut async_dispatcher,
313        callback_state_ptr: *mut async_wait,
314        status: zx_status_t,
315        _packet: *const zx_packet_signal_t,
316    ) {
317        debug_assert!(
318            status == ZX_OK || status == ZX_ERR_CANCELED,
319            "task callback called with status other than ok or canceled"
320        );
321        // SAFETY: This callback's copy of the `async_task` object was refcounted for when we
322        // started the wait.
323        let state = unsafe { CallbackState::from_raw_ptr(callback_state_ptr) };
324        if status == ZX_ERR_CANCELED {
325            state.canceled.store(true, Ordering::Relaxed);
326        }
327        state.waker.wake();
328    }
329}
330
331/// The state for a channel recv future.
332pub struct RecvFutureState {
333    buffer: Option<Buffer>,
334    callback_state: Weak<CallbackState>,
335}
336
337impl Drop for RecvFutureState {
338    fn drop(&mut self) {
339        let Some(state) = self.callback_state.upgrade() else { return };
340        // todo: properly implement cancelation
341        state.waker.wake();
342    }
343}
344
345/// The state for a channel send future.
346pub struct SendFutureState {
347    buffer: Buffer,
348}
349
350/// A channel buffer.
351#[derive(Default)]
352pub struct Buffer {
353    handles: Vec<NullableHandle>,
354    chunks: Vec<Chunk>,
355}
356
357impl Buffer {
358    /// New buffer.
359    pub fn new() -> Self {
360        Self::default()
361    }
362
363    /// Retrieve the handles.
364    pub fn handles(&self) -> &[NullableHandle] {
365        &self.handles
366    }
367
368    /// Retrieve the bytes.
369    pub fn bytes(&self) -> Vec<u8> {
370        self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
371    }
372
373    /// Make a buffer out of handles and chunks.
374    pub fn from_raw(handles: Vec<NullableHandle>, chunks: Vec<Chunk>) -> Self {
375        Self { handles, chunks }
376    }
377
378    /// Make a buffer out of handles and bytes. The bytes will be copied.
379    pub fn from_raw_bytes(handles: Vec<NullableHandle>, bytes: impl AsRef<[u8]>) -> Self {
380        let bytes = bytes.as_ref();
381        assert!(bytes.len() % CHUNK_SIZE == 0);
382        let chunks = bytes
383            .chunks_exact(CHUNK_SIZE)
384            .map(|c| fidl_next::WireU64(u64::from_le_bytes(c.try_into().unwrap())))
385            .collect();
386        Self::from_raw(handles, chunks)
387    }
388}
389
390impl InternalHandleEncoder for Buffer {
391    #[inline]
392    fn __internal_handle_count(&self) -> usize {
393        self.handles.len()
394    }
395}
396
397impl Encoder for Buffer {
398    #[inline]
399    fn bytes_written(&self) -> usize {
400        Encoder::bytes_written(&self.chunks)
401    }
402
403    #[inline]
404    fn write_zeroes(&mut self, len: usize) {
405        Encoder::write_zeroes(&mut self.chunks, len)
406    }
407
408    #[inline]
409    fn write(&mut self, bytes: &[u8]) {
410        Encoder::write(&mut self.chunks, bytes)
411    }
412
413    #[inline]
414    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
415        Encoder::rewrite(&mut self.chunks, pos, bytes)
416    }
417}
418
419impl HandleEncoder for Buffer {
420    fn push_handle(&mut self, handle: NullableHandle) -> Result<(), EncodeError> {
421        self.handles.push(handle);
422        Ok(())
423    }
424
425    fn handles_pushed(&self) -> usize {
426        self.handles.len()
427    }
428}
429
430/// A channel receive buffer.
431pub struct RecvBuffer {
432    buffer: Buffer,
433    chunks_taken: usize,
434    handles_taken: usize,
435}
436
437impl RecvBuffer {
438    /// Create a new receive buffer from a buffer.
439    pub fn new(buffer: Buffer) -> Self {
440        Self { buffer, chunks_taken: 0, handles_taken: 0 }
441    }
442}
443
444unsafe impl Decoder for RecvBuffer {
445    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
446        if count > self.buffer.chunks.len() - self.chunks_taken {
447            return Err(DecodeError::InsufficientData);
448        }
449
450        let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
451        self.chunks_taken += count;
452
453        unsafe { Ok(NonNull::new_unchecked(chunks)) }
454    }
455
456    fn commit(&mut self) {
457        for handle in &mut self.buffer.handles[0..self.handles_taken] {
458            // This handle was taken. To commit the current changes, we need to forget it.
459            let _ = replace(handle, NullableHandle::invalid()).into_raw();
460        }
461    }
462
463    fn finish(&self) -> Result<(), DecodeError> {
464        if self.chunks_taken != self.buffer.chunks.len() {
465            return Err(DecodeError::ExtraBytes {
466                num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
467            });
468        }
469
470        if self.handles_taken != self.buffer.handles.len() {
471            return Err(DecodeError::ExtraHandles {
472                num_extra: self.buffer.handles.len() - self.handles_taken,
473            });
474        }
475
476        Ok(())
477    }
478}
479
480impl InternalHandleDecoder for RecvBuffer {
481    fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
482        if count > self.buffer.handles.len() - self.handles_taken {
483            return Err(DecodeError::InsufficientHandles);
484        }
485
486        for i in self.handles_taken..self.handles_taken + count {
487            let handle = replace(&mut self.buffer.handles[i], NullableHandle::invalid());
488            drop(handle);
489        }
490        self.handles_taken += count;
491
492        Ok(())
493    }
494
495    fn __internal_handles_remaining(&self) -> usize {
496        self.buffer.handles.len() - self.handles_taken
497    }
498}
499
500impl HandleDecoder for RecvBuffer {
501    fn take_raw_handle(&mut self) -> Result<zx_handle_t, DecodeError> {
502        if self.handles_taken >= self.buffer.handles.len() {
503            return Err(DecodeError::InsufficientHandles);
504        }
505
506        let handle = self.buffer.handles[self.handles_taken].raw_handle();
507        self.handles_taken += 1;
508
509        Ok(handle)
510    }
511
512    fn handles_remaining(&mut self) -> usize {
513        self.buffer.handles.len() - self.handles_taken
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use fdf::CurrentDispatcher;
    use fdf_env::test::spawn_in_driver;
    use fidl_next::{ClientDispatcher, ClientEnd, IgnoreEvents};
    use fidl_next_fuchsia_examples_gizmo::Device;

    /// Spawns a client dispatcher task on the current dispatcher and then lets the
    /// driver's async block finish while that task still exists.
    #[fuchsia::test]
    async fn wait_pending_at_dispatcher_shutdown() {
        spawn_in_driver("driver fidl server", async {
            let (server_chan, client_chan) = Channel::create();

            let transport = AsyncChannel::new_on_dispatcher(CurrentDispatcher, client_chan);
            let client_end: ClientEnd<Device, _> = ClientEnd::<Device, _>::from_untyped(transport);
            let client_dispatcher = ClientDispatcher::new(client_end);
            let client = client_dispatcher.client();

            let run_client = async {
                let outcome = client_dispatcher.run(IgnoreEvents).await.map(|_| ());
                println!("client task finished: {:?}", outcome);
            };
            CurrentDispatcher.spawn(run_client).unwrap();

            // Hand both ends back to the caller rather than dropping them in here.
            (server_chan, client)
        });
    }
}
545}