fidl_next_protocol/
mpsc.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! A basic [`Transport`] implementation based on MPSC channels.
6
7use core::fmt;
8use core::marker::PhantomData;
9use core::mem::take;
10use core::pin::Pin;
11use core::ptr::NonNull;
12use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use core::task::{Context, Poll};
14use std::sync::{mpsc, Arc};
15
16use fidl_next_codec::decoder::InternalHandleDecoder;
17use fidl_next_codec::{Chunk, DecodeError, Decoder, CHUNK_SIZE};
18use futures::task::AtomicWaker;
19
20use crate::Transport;
21
22struct SharedEnd {
23    sender_count: AtomicUsize,
24    send_waker: AtomicWaker,
25}
26
27struct Shared {
28    is_closed: AtomicBool,
29    ends: [SharedEnd; 2],
30}
31
32impl Shared {
33    fn close(&self) {
34        let was_closed = self.is_closed.swap(true, Ordering::Relaxed);
35        if !was_closed {
36            for end in &self.ends {
37                end.send_waker.wake();
38            }
39        }
40    }
41}
42
43/// A paired mpsc transport.
44pub struct Mpsc {
45    sender: Sender,
46    receiver: mpsc::Receiver<Vec<Chunk>>,
47}
48
49impl Mpsc {
50    /// Creates two mpscs which can communicate with each other.
51    pub fn new() -> (Self, Self) {
52        let shared = Arc::new(Shared {
53            is_closed: AtomicBool::new(false),
54            ends: [
55                SharedEnd { sender_count: AtomicUsize::new(1), send_waker: AtomicWaker::new() },
56                SharedEnd { sender_count: AtomicUsize::new(1), send_waker: AtomicWaker::new() },
57            ],
58        });
59        let (a_send, a_recv) = mpsc::channel();
60        let (b_send, b_recv) = mpsc::channel();
61        (
62            Mpsc {
63                sender: Sender { shared: shared.clone(), end: 0, sender: a_send },
64                receiver: b_recv,
65            },
66            Mpsc { sender: Sender { shared, end: 1, sender: b_send }, receiver: a_recv },
67        )
68    }
69}
70
/// The error type for paired mpsc transports.
#[derive(Debug)]
pub enum Error {
    /// The mpsc was closed.
    Closed,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Closed => "the mpsc was closed",
        };
        f.write_str(message)
    }
}

impl core::error::Error for Error {}
87
88/// The send end of a paired mpsc transport.
89pub struct Sender {
90    shared: Arc<Shared>,
91    end: usize,
92    sender: mpsc::Sender<Vec<Chunk>>,
93}
94
95impl Clone for Sender {
96    fn clone(&self) -> Self {
97        self.shared.ends[self.end].sender_count.fetch_add(1, Ordering::Relaxed);
98        Self { shared: self.shared.clone(), end: self.end, sender: self.sender.clone() }
99    }
100}
101
102impl Drop for Sender {
103    fn drop(&mut self) {
104        let senders = self.shared.ends[self.end].sender_count.fetch_sub(1, Ordering::Relaxed);
105        if senders == 1 {
106            self.shared.close();
107        }
108    }
109}
110
111/// The send future for a paired mpsc transport.
112pub struct SendFutureState {
113    buffer: Vec<Chunk>,
114}
115
116/// The receive end of a paired mpsc transport.
117pub struct Receiver {
118    shared: Arc<Shared>,
119    end: usize,
120    receiver: mpsc::Receiver<Vec<Chunk>>,
121}
122
/// The state of an in-progress receive on a paired mpsc transport.
///
/// Receiving needs no per-operation state; the marker field only keeps the type opaque.
pub struct RecvFutureState {
    _phantom: PhantomData<()>,
}
127
128/// A received message buffer.
129pub struct RecvBuffer {
130    chunks: Vec<Chunk>,
131    chunks_taken: usize,
132}
133
134impl InternalHandleDecoder for RecvBuffer {
135    fn __internal_take_handles(&mut self, _: usize) -> Result<(), DecodeError> {
136        Err(DecodeError::InsufficientHandles)
137    }
138
139    fn __internal_handles_remaining(&self) -> usize {
140        0
141    }
142}
143
144unsafe impl Decoder for RecvBuffer {
145    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
146        if count > self.chunks.len() - self.chunks_taken {
147            return Err(DecodeError::InsufficientData);
148        }
149
150        let chunks = unsafe { self.chunks.as_mut_ptr().add(self.chunks_taken) };
151        self.chunks_taken += count;
152
153        unsafe { Ok(NonNull::new_unchecked(chunks)) }
154    }
155
156    fn finish(&mut self) -> Result<(), DecodeError> {
157        if self.chunks_taken != self.chunks.len() {
158            return Err(DecodeError::ExtraBytes {
159                num_extra: (self.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
160            });
161        }
162
163        Ok(())
164    }
165}
166
167impl Transport for Mpsc {
168    type Error = Error;
169
170    fn split(self) -> (Self::Sender, Self::Receiver) {
171        let receiver = Receiver {
172            shared: self.sender.shared.clone(),
173            end: self.sender.end,
174            receiver: self.receiver,
175        };
176        (self.sender, receiver)
177    }
178
179    type Sender = Sender;
180    type SendBuffer = Vec<Chunk>;
181    type SendFutureState = SendFutureState;
182
183    fn acquire(_: &Self::Sender) -> Self::SendBuffer {
184        Vec::new()
185    }
186
187    fn begin_send(_: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
188        SendFutureState { buffer }
189    }
190
191    fn poll_send(
192        mut future_state: Pin<&mut SendFutureState>,
193        _: &mut Context<'_>,
194        sender: &Self::Sender,
195    ) -> Poll<Result<(), Error>> {
196        if sender.shared.is_closed.load(Ordering::Relaxed) {
197            return Poll::Ready(Err(Error::Closed));
198        }
199
200        let chunks = take(&mut future_state.buffer);
201        match sender.sender.send(chunks) {
202            Ok(()) => {
203                sender.shared.ends[sender.end].send_waker.wake();
204                Poll::Ready(Ok(()))
205            }
206            Err(_) => Poll::Ready(Err(Error::Closed)),
207        }
208    }
209
210    fn close(sender: &Self::Sender) {
211        sender.shared.close();
212    }
213
214    type Receiver = Receiver;
215    type RecvFutureState = RecvFutureState;
216    type RecvBuffer = RecvBuffer;
217
218    fn begin_recv(_: &mut Self::Receiver) -> Self::RecvFutureState {
219        RecvFutureState { _phantom: PhantomData }
220    }
221
222    fn poll_recv(
223        _: Pin<&mut Self::RecvFutureState>,
224        cx: &mut Context<'_>,
225        receiver: &mut Self::Receiver,
226    ) -> Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
227        if receiver.shared.is_closed.load(Ordering::Relaxed) {
228            return Poll::Ready(Ok(None));
229        }
230
231        receiver.shared.ends[1 - receiver.end].send_waker.register(cx.waker());
232        match receiver.receiver.try_recv() {
233            Ok(chunks) => Poll::Ready(Ok(Some(RecvBuffer { chunks, chunks_taken: 0 }))),
234            Err(mpsc::TryRecvError::Empty) => Poll::Pending,
235            Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Ok(None)),
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use fuchsia_async as fasync;

    use super::Mpsc;
    use crate::testing::*;

    // Each test creates a fresh connected pair and runs one of the shared transport
    // conformance helpers from `crate::testing` against it.

    #[fasync::run_singlethreaded(test)]
    async fn close_on_drop() {
        let (client, server) = Mpsc::new();
        test_close_on_drop(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn send_receive() {
        let (client, server) = Mpsc::new();
        test_one_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn two_way() {
        let (client, server) = Mpsc::new();
        test_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn multiple_two_way() {
        let (client, server) = Mpsc::new();
        test_multiple_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn event() {
        let (client, server) = Mpsc::new();
        test_event(client, server).await;
    }
}