1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10use std::sync::{Arc, Mutex};
11
12use fidl_next_codec::{Encode, EncodeError, EncoderExt};
13
14use crate::lockers::Lockers;
15use crate::{decode_header, encode_header, ProtocolError, SendFuture, Transport, TransportExt};
16
17use super::lockers::LockerError;
18
19struct Shared<T: Transport> {
20 responses: Mutex<Lockers<T::RecvBuffer>>,
21}
22
23impl<T: Transport> Shared<T> {
24 fn new() -> Self {
25 Self { responses: Mutex::new(Lockers::new()) }
26 }
27}
28
29pub struct ClientSender<T: Transport> {
31 shared: Arc<Shared<T>>,
32 sender: T::Sender,
33}
34
35impl<T: Transport> ClientSender<T> {
36 pub fn close(&self) {
38 T::close(&self.sender);
39 }
40
41 pub fn send_one_way<M>(
43 &self,
44 ordinal: u64,
45 request: &mut M,
46 ) -> Result<SendFuture<'_, T>, EncodeError>
47 where
48 M: Encode<T::SendBuffer>,
49 {
50 self.send_message(0, ordinal, request)
51 }
52
53 pub fn send_two_way<M>(
55 &self,
56 ordinal: u64,
57 request: &mut M,
58 ) -> Result<ResponseFuture<'_, T>, EncodeError>
59 where
60 M: Encode<T::SendBuffer>,
61 {
62 let index = self.shared.responses.lock().unwrap().alloc(ordinal);
63
64 match self.send_message(index + 1, ordinal, request) {
66 Ok(future) => Ok(ResponseFuture {
67 shared: &self.shared,
68 index,
69 state: ResponseFutureState::Sending(future),
70 }),
71 Err(e) => {
72 self.shared.responses.lock().unwrap().free(index);
73 Err(e)
74 }
75 }
76 }
77
78 fn send_message<M>(
79 &self,
80 txid: u32,
81 ordinal: u64,
82 message: &mut M,
83 ) -> Result<SendFuture<'_, T>, EncodeError>
84 where
85 M: Encode<T::SendBuffer>,
86 {
87 let mut buffer = T::acquire(&self.sender);
88 encode_header::<T>(&mut buffer, txid, ordinal)?;
89 buffer.encode_next(message)?;
90 Ok(T::send(&self.sender, buffer))
91 }
92}
93
94impl<T: Transport> Clone for ClientSender<T> {
95 fn clone(&self) -> Self {
96 Self { shared: self.shared.clone(), sender: self.sender.clone() }
97 }
98}
99
100enum ResponseFutureState<'a, T: Transport> {
101 Sending(SendFuture<'a, T>),
102 Receiving,
103 Completed,
106}
107
108pub struct ResponseFuture<'a, T: Transport> {
110 shared: &'a Shared<T>,
111 index: u32,
112 state: ResponseFutureState<'a, T>,
113}
114
115impl<T: Transport> Drop for ResponseFuture<'_, T> {
116 fn drop(&mut self) {
117 let mut responses = self.shared.responses.lock().unwrap();
118 match self.state {
119 ResponseFutureState::Sending(_) => responses.free(self.index),
122 ResponseFutureState::Receiving => {
123 if responses.get(self.index).unwrap().cancel() {
124 responses.free(self.index);
125 }
126 }
127 ResponseFutureState::Completed => (),
129 }
130 }
131}
132
133impl<T: Transport> ResponseFuture<'_, T> {
134 fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
135 let mut responses = self.shared.responses.lock().unwrap();
136 if let Some(ready) = responses.get(self.index).unwrap().read(cx.waker()) {
137 responses.free(self.index);
138 self.state = ResponseFutureState::Completed;
139 Poll::Ready(Ok(ready))
140 } else {
141 Poll::Pending
142 }
143 }
144}
145
146impl<T: Transport> Future for ResponseFuture<'_, T> {
147 type Output = Result<T::RecvBuffer, T::Error>;
148
149 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 let this = unsafe { Pin::into_inner_unchecked(self) };
152
153 match &mut this.state {
154 ResponseFutureState::Sending(future) => {
155 let pinned = unsafe { Pin::new_unchecked(future) };
157 match pinned.poll(cx) {
158 Poll::Pending => Poll::Pending,
160 Poll::Ready(Ok(())) => {
161 this.state = ResponseFutureState::Receiving;
164 this.poll_receiving(cx)
165 }
166 Poll::Ready(Err(e)) => {
167 this.shared.responses.lock().unwrap().free(this.index);
171 this.state = ResponseFutureState::Completed;
172 Poll::Ready(Err(e))
173 }
174 }
175 }
176 ResponseFutureState::Receiving => this.poll_receiving(cx),
177 ResponseFutureState::Completed => unreachable!(),
180 }
181 }
182}
183
184pub trait ClientHandler<T: Transport> {
186 fn on_event(&mut self, sender: &ClientSender<T>, ordinal: u64, buffer: T::RecvBuffer);
192}
193
194pub struct Client<T: Transport> {
198 sender: ClientSender<T>,
199 receiver: T::Receiver,
200}
201
202impl<T: Transport> Client<T> {
203 pub fn new(transport: T) -> Self {
205 let (sender, receiver) = transport.split();
206 let shared = Arc::new(Shared::new());
207 Self { sender: ClientSender { shared, sender }, receiver }
208 }
209
210 pub fn sender(&self) -> &ClientSender<T> {
212 &self.sender
213 }
214
215 pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
217 where
218 H: ClientHandler<T>,
219 {
220 let result = self.run_to_completion(&mut handler).await;
221 self.sender.shared.responses.lock().unwrap().wake_all();
222
223 result
224 }
225
226 pub async fn run_sender(&mut self) -> Result<(), ProtocolError<T::Error>> {
228 self.run(IgnoreEvents).await
229 }
230
231 async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
232 where
233 H: ClientHandler<T>,
234 {
235 while let Some(mut buffer) =
236 T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?
237 {
238 let (txid, ordinal) =
239 decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
240 if txid == 0 {
241 handler.on_event(&self.sender, ordinal, buffer);
242 } else {
243 let mut responses = self.sender.shared.responses.lock().unwrap();
244 let locker = responses
245 .get(txid - 1)
246 .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
247
248 match locker.write(ordinal, buffer) {
249 Ok(false) => (),
251 Ok(true) => responses.free(txid - 1),
253 Err(LockerError::NotWriteable) => {
254 return Err(ProtocolError::UnrequestedResponse(txid));
255 }
256 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
257 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
258 }
259 }
260 }
261 }
262
263 Ok(())
264 }
265}
266
/// A [`ClientHandler`] that silently discards every event.
///
/// [`Client::run_sender`] uses it for clients that only care about responses.
#[derive(Debug, Clone, Copy, Default)]
pub struct IgnoreEvents;
269
270impl<T: Transport> ClientHandler<T> for IgnoreEvents {
271 fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
272}