fidl_next_protocol/fuchsia/channel.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! A transport implementation which uses Zircon channels.
6
7use core::mem::replace;
8use core::pin::Pin;
9use core::ptr::NonNull;
10use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use core::task::{Context, Poll};
12use std::sync::Arc;
13
14use fidl_next_codec::decoder::InternalHandleDecoder;
15use fidl_next_codec::encoder::InternalHandleEncoder;
16use fidl_next_codec::fuchsia::{HandleDecoder, HandleEncoder};
17use fidl_next_codec::{Chunk, DecodeError, Decoder, EncodeError, Encoder, CHUNK_SIZE};
18use fuchsia_async::{RWHandle, ReadableHandle as _};
19use futures::task::AtomicWaker;
20use zx::sys::{
21    zx_channel_read, zx_channel_write, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_PEER_CLOSED,
22    ZX_ERR_SHOULD_WAIT, ZX_OK,
23};
24use zx::{AsHandleRef as _, Channel, Handle, Status};
25
26use crate::Transport;
27
28struct Shared {
29    is_closed: AtomicBool,
30    sender_count: AtomicUsize,
31    closed_waker: AtomicWaker,
32    channel: RWHandle<Channel>,
33    // TODO: recycle send/recv buffers to reduce allocations
34}
35
36impl Shared {
37    fn new(channel: Channel) -> Self {
38        Self {
39            is_closed: AtomicBool::new(false),
40            sender_count: AtomicUsize::new(1),
41            closed_waker: AtomicWaker::new(),
42            channel: RWHandle::new(channel),
43        }
44    }
45
46    fn close(&self) {
47        self.is_closed.store(true, Ordering::Relaxed);
48        self.closed_waker.wake();
49    }
50}
51
52/// A channel sender.
53pub struct Sender {
54    shared: Arc<Shared>,
55}
56
57impl Drop for Sender {
58    fn drop(&mut self) {
59        let senders = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
60        if senders == 1 {
61            self.shared.close();
62        }
63    }
64}
65
66impl Clone for Sender {
67    fn clone(&self) -> Self {
68        self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
69        Self { shared: self.shared.clone() }
70    }
71}
72
73/// A channel buffer.
74#[derive(Default)]
75pub struct Buffer {
76    handles: Vec<Handle>,
77    chunks: Vec<Chunk>,
78}
79
80impl Buffer {
81    /// New buffer.
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    /// Retrieve the handles.
87    pub fn handles(&self) -> &[Handle] {
88        &self.handles
89    }
90
91    /// Retrieve the bytes.
92    pub fn bytes(&self) -> Vec<u8> {
93        self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
94    }
95}
96
97impl InternalHandleEncoder for Buffer {
98    #[inline]
99    fn __internal_handle_count(&self) -> usize {
100        self.handles.len()
101    }
102}
103
104impl Encoder for Buffer {
105    #[inline]
106    fn bytes_written(&self) -> usize {
107        Encoder::bytes_written(&self.chunks)
108    }
109
110    #[inline]
111    fn write_zeroes(&mut self, len: usize) {
112        Encoder::write_zeroes(&mut self.chunks, len)
113    }
114
115    #[inline]
116    fn write(&mut self, bytes: &[u8]) {
117        Encoder::write(&mut self.chunks, bytes)
118    }
119
120    #[inline]
121    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
122        Encoder::rewrite(&mut self.chunks, pos, bytes)
123    }
124}
125
126impl HandleEncoder for Buffer {
127    fn push_handle(&mut self, handle: Handle) -> Result<(), EncodeError> {
128        self.handles.push(handle);
129        Ok(())
130    }
131
132    fn handles_pushed(&self) -> usize {
133        self.handles.len()
134    }
135}
136
137/// The state for a channel send future.
138pub struct SendFutureState {
139    buffer: Buffer,
140}
141
142/// A channel receiver.
143pub struct Receiver {
144    shared: Arc<Shared>,
145}
146
147/// The state for a channel receive future.
148pub struct RecvFutureState {
149    buffer: Option<Buffer>,
150}
151
152/// A channel receive buffer.
153pub struct RecvBuffer {
154    buffer: Buffer,
155    chunks_taken: usize,
156    handles_taken: usize,
157}
158
159unsafe impl Decoder for RecvBuffer {
160    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
161        if count > self.buffer.chunks.len() - self.chunks_taken {
162            return Err(DecodeError::InsufficientData);
163        }
164
165        let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
166        self.chunks_taken += count;
167
168        unsafe { Ok(NonNull::new_unchecked(chunks)) }
169    }
170
171    fn finish(&mut self) -> Result<(), DecodeError> {
172        if self.chunks_taken != self.buffer.chunks.len() {
173            return Err(DecodeError::ExtraBytes {
174                num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
175            });
176        }
177
178        if self.handles_taken != self.buffer.handles.len() {
179            return Err(DecodeError::ExtraHandles {
180                num_extra: self.buffer.handles.len() - self.handles_taken,
181            });
182        }
183
184        Ok(())
185    }
186}
187
188impl InternalHandleDecoder for RecvBuffer {
189    fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
190        if count > self.buffer.handles.len() - self.handles_taken {
191            return Err(DecodeError::InsufficientHandles);
192        }
193
194        for i in self.handles_taken..self.handles_taken + count {
195            let handle = replace(&mut self.buffer.handles[i], Handle::invalid());
196            drop(handle);
197        }
198        self.handles_taken += count;
199
200        Ok(())
201    }
202
203    fn __internal_handles_remaining(&self) -> usize {
204        self.buffer.handles.len() - self.handles_taken
205    }
206}
207
208impl HandleDecoder for RecvBuffer {
209    fn take_handle(&mut self) -> Result<Handle, DecodeError> {
210        if self.handles_taken >= self.buffer.handles.len() {
211            return Err(DecodeError::InsufficientHandles);
212        }
213
214        let handle = replace(&mut self.buffer.handles[self.handles_taken], Handle::invalid());
215        self.handles_taken += 1;
216
217        Ok(handle)
218    }
219
220    fn handles_remaining(&mut self) -> usize {
221        self.buffer.handles.len() - self.handles_taken
222    }
223}
224
225impl Transport for Channel {
226    type Error = Status;
227
228    fn split(self) -> (Self::Sender, Self::Receiver) {
229        let shared = Arc::new(Shared::new(self));
230        (Sender { shared: shared.clone() }, Receiver { shared })
231    }
232
233    type Sender = Sender;
234    type SendBuffer = Buffer;
235    type SendFutureState = SendFutureState;
236
237    fn acquire(_: &Self::Sender) -> Self::SendBuffer {
238        Buffer::new()
239    }
240
241    fn begin_send(_: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
242        SendFutureState { buffer }
243    }
244
245    fn poll_send(
246        mut future_state: Pin<&mut Self::SendFutureState>,
247        _: &mut Context<'_>,
248        sender: &Self::Sender,
249    ) -> Poll<Result<(), Self::Error>> {
250        let result = unsafe {
251            zx_channel_write(
252                sender.shared.channel.get_ref().raw_handle(),
253                0,
254                future_state.buffer.chunks.as_ptr().cast::<u8>(),
255                (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
256                future_state.buffer.handles.as_ptr().cast(),
257                future_state.buffer.handles.len() as u32,
258            )
259        };
260
261        if result == ZX_OK {
262            // Handles were written to the channel, so we must not drop them.
263            unsafe {
264                future_state.buffer.handles.set_len(0);
265            }
266            Poll::Ready(Ok(()))
267        } else {
268            Poll::Ready(Err(Status::from_raw(result)))
269        }
270    }
271
272    fn close(sender: &Self::Sender) {
273        sender.shared.close();
274    }
275
276    type Receiver = Receiver;
277    type RecvFutureState = RecvFutureState;
278    type RecvBuffer = RecvBuffer;
279
280    fn begin_recv(_: &mut Self::Receiver) -> Self::RecvFutureState {
281        RecvFutureState { buffer: Some(Buffer::new()) }
282    }
283
284    fn poll_recv(
285        mut future_state: Pin<&mut Self::RecvFutureState>,
286        cx: &mut Context<'_>,
287        receiver: &mut Self::Receiver,
288    ) -> Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
289        let buffer = future_state.buffer.as_mut().unwrap();
290
291        let mut actual_bytes = 0;
292        let mut actual_handles = 0;
293
294        loop {
295            let result = unsafe {
296                zx_channel_read(
297                    receiver.shared.channel.get_ref().raw_handle(),
298                    0,
299                    buffer.chunks.as_mut_ptr().cast(),
300                    buffer.handles.as_mut_ptr().cast(),
301                    (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
302                    buffer.handles.capacity() as u32,
303                    &mut actual_bytes,
304                    &mut actual_handles,
305                )
306            };
307
308            match result {
309                ZX_OK => {
310                    unsafe {
311                        buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
312                        buffer.handles.set_len(actual_handles as usize);
313                    }
314                    return Poll::Ready(Ok(Some(RecvBuffer {
315                        buffer: future_state.buffer.take().unwrap(),
316                        chunks_taken: 0,
317                        handles_taken: 0,
318                    })));
319                }
320                ZX_ERR_PEER_CLOSED => return Poll::Ready(Ok(None)),
321                ZX_ERR_BUFFER_TOO_SMALL => {
322                    let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
323                    buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
324                    buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
325                }
326                ZX_ERR_SHOULD_WAIT => {
327                    if matches!(receiver.shared.channel.need_readable(cx)?, Poll::Pending) {
328                        receiver.shared.closed_waker.register(cx.waker());
329                        if receiver.shared.is_closed.load(Ordering::Relaxed) {
330                            return Poll::Ready(Ok(None));
331                        }
332                        return Poll::Pending;
333                    }
334                }
335                raw => return Poll::Ready(Err(Status::from_raw(raw))),
336            }
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use fuchsia_async as fasync;
    use zx::Channel;

    use crate::testing::transport::*;

    // Each test creates a fresh channel pair and passes both ends to the shared transport
    // conformance helper with the matching name.

    #[fasync::run_singlethreaded(test)]
    async fn close_on_drop() {
        let (client, server) = Channel::create();
        test_close_on_drop(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn one_way() {
        let (client, server) = Channel::create();
        test_one_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn two_way() {
        let (client, server) = Channel::create();
        test_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn multiple_two_way() {
        let (client, server) = Channel::create();
        test_multiple_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn event() {
        let (client, server) = Channel::create();
        test_event(client, server).await;
    }
}