fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_next_codec::{Encode, EncodeError, EncoderExt};
12
13use crate::concurrency::sync::{Arc, Mutex};
14
15use crate::{ProtocolError, Transport, decode_epitaph, decode_header, encode_header};
16
17use super::connection::{Connection, ORDINAL_EPITAPH, SendFuture};
18use super::lockers::{LockerError, Lockers};
19
20struct ClientSenderInner<T: Transport> {
21 connection: Connection<T>,
22 responses: Mutex<Lockers<T::RecvBuffer>>,
23}
24
25impl<T: Transport> ClientSenderInner<T> {
26 fn new(shared: T::Shared) -> Self {
27 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
28 }
29}
30
31pub struct ClientSender<T: Transport> {
33 inner: Arc<ClientSenderInner<T>>,
34}
35
36impl<T: Transport> Drop for ClientSender<T> {
37 fn drop(&mut self) {
38 if Arc::strong_count(&self.inner) == 2 {
39 self.close();
42 }
43 }
44}
45
46impl<T: Transport> ClientSender<T> {
47 pub fn close(&self) {
49 self.inner.connection.stop();
50 }
51
52 pub fn send_one_way<M>(
54 &self,
55 ordinal: u64,
56 request: M,
57 ) -> Result<SendFuture<'_, T>, EncodeError>
58 where
59 M: Encode<T::SendBuffer>,
60 {
61 self.send_message(0, ordinal, request)
62 }
63
64 pub fn send_two_way<M>(
66 &self,
67 ordinal: u64,
68 request: M,
69 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
70 where
71 M: Encode<T::SendBuffer>,
72 {
73 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
74
75 match self.send_message(index + 1, ordinal, request) {
77 Ok(send_future) => {
78 Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
79 }
80 Err(e) => {
81 self.inner.responses.lock().unwrap().free(index);
82 Err(e)
83 }
84 }
85 }
86
87 fn send_message<M>(
88 &self,
89 txid: u32,
90 ordinal: u64,
91 message: M,
92 ) -> Result<SendFuture<'_, T>, EncodeError>
93 where
94 M: Encode<T::SendBuffer>,
95 {
96 self.inner.connection.send_with(|buffer| {
97 encode_header::<T>(buffer, txid, ordinal)?;
98 buffer.encode_next(message)
99 })
100 }
101}
102
103impl<T: Transport> Clone for ClientSender<T> {
104 fn clone(&self) -> Self {
105 Self { inner: self.inner.clone() }
106 }
107}
108
109pub struct TwoWayResponseFuture<'a, T: Transport> {
111 inner: &'a ClientSenderInner<T>,
112 index: Option<u32>,
113}
114
115impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
116 fn drop(&mut self) {
117 if let Some(index) = self.index {
119 let mut responses = self.inner.responses.lock().unwrap();
120 if responses.get(index).unwrap().cancel() {
121 responses.free(index);
122 }
123 }
124 }
125}
126
127impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
128 type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
129
130 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
131 let this = Pin::into_inner(self);
132 let Some(index) = this.index else {
133 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
134 };
135
136 let mut responses = this.inner.responses.lock().unwrap();
137 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
138 Ok(ready)
139 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
140 Err(termination_reason)
141 } else {
142 return Poll::Pending;
143 };
144
145 responses.free(index);
146 this.index = None;
147 Poll::Ready(ready)
148 }
149}
150
151pub struct TwoWayRequestFuture<'a, T: Transport> {
153 inner: &'a ClientSenderInner<T>,
154 index: Option<u32>,
155 send_future: SendFuture<'a, T>,
156}
157
158impl<T: Transport> Drop for TwoWayRequestFuture<'_, T> {
159 fn drop(&mut self) {
160 if let Some(index) = self.index {
161 let mut responses = self.inner.responses.lock().unwrap();
162
163 responses.free(index);
166 }
167 }
168}
169
170impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
171 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
172
173 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174 let this = unsafe { Pin::into_inner_unchecked(self) };
176
177 let Some(index) = this.index else {
178 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
179 };
180
181 let send_future = unsafe { Pin::new_unchecked(&mut this.send_future) };
183
184 let result = ready!(send_future.poll(cx));
185 this.index = None;
186 if let Err(error) = result {
187 this.inner.responses.lock().unwrap().free(index);
189 Poll::Ready(Err(error))
190 } else {
191 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
192 }
193 }
194}
195
196pub trait ClientHandler<T: Transport> {
198 fn on_event(
203 &mut self,
204 sender: &ClientSender<T>,
205 ordinal: u64,
206 buffer: T::RecvBuffer,
207 ) -> impl Future<Output = ()> + Send;
208}
209
210pub struct Client<T: Transport> {
214 sender: ClientSender<T>,
215 exclusive: T::Exclusive,
216 is_terminated: bool,
217}
218
219impl<T: Transport> Drop for Client<T> {
220 fn drop(&mut self) {
221 if !self.is_terminated {
222 unsafe {
224 self.terminate(ProtocolError::Stopped);
225 }
226 }
227 }
228}
229
230impl<T: Transport> Client<T> {
231 pub fn new(transport: T) -> Self {
233 let (shared, exclusive) = transport.split();
234 let inner = Arc::new(ClientSenderInner::new(shared));
235 Self { sender: ClientSender { inner }, exclusive, is_terminated: false }
236 }
237
238 unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
242 unsafe {
244 self.sender.inner.connection.terminate(error);
245 }
246 self.sender.inner.responses.lock().unwrap().wake_all();
247 }
248
249 pub fn sender(&self) -> &ClientSender<T> {
251 &self.sender
252 }
253
254 pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
256 where
257 H: ClientHandler<T>,
258 {
259 let error = loop {
265 let result = unsafe { self.run_one(&mut handler).await };
267 if let Err(error) = result {
268 break error;
269 }
270 };
271
272 unsafe {
274 self.terminate(error.clone());
275 }
276 self.is_terminated = true;
277
278 match error {
279 ProtocolError::Stopped => Ok(handler),
282
283 _ => Err(error),
285 }
286 }
287
288 async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
292 where
293 H: ClientHandler<T>,
294 {
295 let mut buffer = unsafe { self.sender.inner.connection.recv(&mut self.exclusive).await? };
297
298 let (txid, ordinal) =
299 decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
300
301 if ordinal == ORDINAL_EPITAPH {
302 let epitaph =
303 decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
304 return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
305 } else if txid == 0 {
306 handler.on_event(&self.sender, ordinal, buffer).await;
307 } else {
308 let mut responses = self.sender.inner.responses.lock().unwrap();
309 let locker = responses
310 .get(txid - 1)
311 .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
312
313 match locker.write(ordinal, buffer) {
314 Ok(false) => (),
316 Ok(true) => responses.free(txid - 1),
318 Err(LockerError::NotWriteable) => {
319 return Err(ProtocolError::UnrequestedResponse { txid });
320 }
321 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
322 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
323 }
324 }
325 }
326
327 Ok(())
328 }
329
330 pub async fn run_sender(self) -> Result<(), ProtocolError<T::Error>> {
332 self.run(IgnoreEvents).await.map(|_| ())
333 }
334}
335
/// A [`ClientHandler`] that discards every event it receives.
///
/// Stateless marker type, so it derives the common value traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct IgnoreEvents;
338
339impl<T: Transport> ClientHandler<T> for IgnoreEvents {
340 async fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
341}