fidl_next_protocol/
mpsc.rs
1use core::fmt;
8use core::marker::PhantomData;
9use core::mem::take;
10use core::pin::Pin;
11use core::ptr::NonNull;
12use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use core::task::{Context, Poll};
14use std::sync::{mpsc, Arc};
15
16use fidl_next_codec::decoder::InternalHandleDecoder;
17use fidl_next_codec::{Chunk, DecodeError, Decoder, CHUNK_SIZE};
18use futures::task::AtomicWaker;
19
20use crate::Transport;
21
22struct SharedEnd {
23 sender_count: AtomicUsize,
24 send_waker: AtomicWaker,
25}
26
27struct Shared {
28 is_closed: AtomicBool,
29 ends: [SharedEnd; 2],
30}
31
32impl Shared {
33 fn close(&self) {
34 let was_closed = self.is_closed.swap(true, Ordering::Relaxed);
35 if !was_closed {
36 for end in &self.ends {
37 end.send_waker.wake();
38 }
39 }
40 }
41}
42
43pub struct Mpsc {
45 sender: Sender,
46 receiver: mpsc::Receiver<Vec<Chunk>>,
47}
48
49impl Mpsc {
50 pub fn new() -> (Self, Self) {
52 let shared = Arc::new(Shared {
53 is_closed: AtomicBool::new(false),
54 ends: [
55 SharedEnd { sender_count: AtomicUsize::new(1), send_waker: AtomicWaker::new() },
56 SharedEnd { sender_count: AtomicUsize::new(1), send_waker: AtomicWaker::new() },
57 ],
58 });
59 let (a_send, a_recv) = mpsc::channel();
60 let (b_send, b_recv) = mpsc::channel();
61 (
62 Mpsc {
63 sender: Sender { shared: shared.clone(), end: 0, sender: a_send },
64 receiver: b_recv,
65 },
66 Mpsc { sender: Sender { shared, end: 1, sender: b_send }, receiver: a_recv },
67 )
68 }
69}
70
/// Errors produced by the in-memory `Mpsc` transport.
///
/// Derives the common value traits so callers can copy, compare, and hash errors
/// (e.g. `assert_eq!(err, Error::Closed)`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// The transport pair was closed, or the peer endpoint can no longer receive messages.
    Closed,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Closed => write!(f, "the mpsc was closed"),
        }
    }
}

impl core::error::Error for Error {}
87
88pub struct Sender {
90 shared: Arc<Shared>,
91 end: usize,
92 sender: mpsc::Sender<Vec<Chunk>>,
93}
94
95impl Clone for Sender {
96 fn clone(&self) -> Self {
97 self.shared.ends[self.end].sender_count.fetch_add(1, Ordering::Relaxed);
98 Self { shared: self.shared.clone(), end: self.end, sender: self.sender.clone() }
99 }
100}
101
102impl Drop for Sender {
103 fn drop(&mut self) {
104 let senders = self.shared.ends[self.end].sender_count.fetch_sub(1, Ordering::Relaxed);
105 if senders == 1 {
106 self.shared.close();
107 }
108 }
109}
110
111pub struct SendFutureState {
113 buffer: Vec<Chunk>,
114}
115
116pub struct Receiver {
118 shared: Arc<Shared>,
119 end: usize,
120 receiver: mpsc::Receiver<Vec<Chunk>>,
121}
122
/// State of an in-progress receive. `Mpsc` needs none, so this type is zero-sized.
pub struct RecvFutureState {
    /// Private marker field; it keeps the type from being constructed outside this module.
    _phantom: PhantomData<()>,
}
127
128pub struct RecvBuffer {
130 chunks: Vec<Chunk>,
131 chunks_taken: usize,
132}
133
134impl InternalHandleDecoder for RecvBuffer {
135 fn __internal_take_handles(&mut self, _: usize) -> Result<(), DecodeError> {
136 Err(DecodeError::InsufficientHandles)
137 }
138
139 fn __internal_handles_remaining(&self) -> usize {
140 0
141 }
142}
143
144unsafe impl Decoder for RecvBuffer {
145 fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
146 if count > self.chunks.len() - self.chunks_taken {
147 return Err(DecodeError::InsufficientData);
148 }
149
150 let chunks = unsafe { self.chunks.as_mut_ptr().add(self.chunks_taken) };
151 self.chunks_taken += count;
152
153 unsafe { Ok(NonNull::new_unchecked(chunks)) }
154 }
155
156 fn finish(&mut self) -> Result<(), DecodeError> {
157 if self.chunks_taken != self.chunks.len() {
158 return Err(DecodeError::ExtraBytes {
159 num_extra: (self.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
160 });
161 }
162
163 Ok(())
164 }
165}
166
167impl Transport for Mpsc {
168 type Error = Error;
169
170 fn split(self) -> (Self::Sender, Self::Receiver) {
171 let receiver = Receiver {
172 shared: self.sender.shared.clone(),
173 end: self.sender.end,
174 receiver: self.receiver,
175 };
176 (self.sender, receiver)
177 }
178
179 type Sender = Sender;
180 type SendBuffer = Vec<Chunk>;
181 type SendFutureState = SendFutureState;
182
183 fn acquire(_: &Self::Sender) -> Self::SendBuffer {
184 Vec::new()
185 }
186
187 fn begin_send(_: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
188 SendFutureState { buffer }
189 }
190
191 fn poll_send(
192 mut future_state: Pin<&mut SendFutureState>,
193 _: &mut Context<'_>,
194 sender: &Self::Sender,
195 ) -> Poll<Result<(), Error>> {
196 if sender.shared.is_closed.load(Ordering::Relaxed) {
197 return Poll::Ready(Err(Error::Closed));
198 }
199
200 let chunks = take(&mut future_state.buffer);
201 match sender.sender.send(chunks) {
202 Ok(()) => {
203 sender.shared.ends[sender.end].send_waker.wake();
204 Poll::Ready(Ok(()))
205 }
206 Err(_) => Poll::Ready(Err(Error::Closed)),
207 }
208 }
209
210 fn close(sender: &Self::Sender) {
211 sender.shared.close();
212 }
213
214 type Receiver = Receiver;
215 type RecvFutureState = RecvFutureState;
216 type RecvBuffer = RecvBuffer;
217
218 fn begin_recv(_: &mut Self::Receiver) -> Self::RecvFutureState {
219 RecvFutureState { _phantom: PhantomData }
220 }
221
222 fn poll_recv(
223 _: Pin<&mut Self::RecvFutureState>,
224 cx: &mut Context<'_>,
225 receiver: &mut Self::Receiver,
226 ) -> Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
227 if receiver.shared.is_closed.load(Ordering::Relaxed) {
228 return Poll::Ready(Ok(None));
229 }
230
231 receiver.shared.ends[1 - receiver.end].send_waker.register(cx.waker());
232 match receiver.receiver.try_recv() {
233 Ok(chunks) => Poll::Ready(Ok(Some(RecvBuffer { chunks, chunks_taken: 0 }))),
234 Err(mpsc::TryRecvError::Empty) => Poll::Pending,
235 Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Ok(None)),
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use fuchsia_async as fasync;

    use super::Mpsc;
    use crate::testing::*;

    // Each test runs one scenario from the shared transport test suite against a freshly
    // created `Mpsc` pair.

    #[fasync::run_singlethreaded(test)]
    async fn close_on_drop() {
        let (client, server) = Mpsc::new();
        test_close_on_drop(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn send_receive() {
        let (client, server) = Mpsc::new();
        test_one_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn two_way() {
        let (client, server) = Mpsc::new();
        test_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn multiple_two_way() {
        let (client, server) = Mpsc::new();
        test_multiple_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn event() {
        let (client, server) = Mpsc::new();
        test_event(client, server).await;
    }
}