fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_next_codec::{Encode, EncodeError, EncoderExt};
12use pin_project::{pin_project, pinned_drop};
13
14use crate::concurrency::sync::{Arc, Mutex};
15use crate::endpoints::connection::{Connection, ORDINAL_EPITAPH};
16use crate::endpoints::lockers::{LockerError, Lockers};
17use crate::{ProtocolError, SendFuture, Transport, decode_epitaph, decode_header, encode_header};
18
19struct ClientInner<T: Transport> {
20 connection: Connection<T>,
21 responses: Mutex<Lockers<T::RecvBuffer>>,
22}
23
24impl<T: Transport> ClientInner<T> {
25 fn new(shared: T::Shared) -> Self {
26 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
27 }
28}
29
30pub struct Client<T: Transport> {
32 inner: Arc<ClientInner<T>>,
33}
34
35impl<T: Transport> Drop for Client<T> {
36 fn drop(&mut self) {
37 if Arc::strong_count(&self.inner) == 2 {
38 self.close();
41 }
42 }
43}
44
45impl<T: Transport> Client<T> {
46 pub fn close(&self) {
48 self.inner.connection.stop();
49 }
50
51 pub fn send_one_way<M>(
53 &self,
54 ordinal: u64,
55 request: M,
56 ) -> Result<SendFuture<'_, T>, EncodeError>
57 where
58 M: Encode<T::SendBuffer>,
59 {
60 self.send_message(0, ordinal, request)
61 }
62
63 pub fn send_two_way<M>(
65 &self,
66 ordinal: u64,
67 request: M,
68 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
69 where
70 M: Encode<T::SendBuffer>,
71 {
72 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
73
74 match self.send_message(index + 1, ordinal, request) {
76 Ok(send_future) => {
77 Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
78 }
79 Err(e) => {
80 self.inner.responses.lock().unwrap().free(index);
81 Err(e)
82 }
83 }
84 }
85
86 fn send_message<M>(
87 &self,
88 txid: u32,
89 ordinal: u64,
90 message: M,
91 ) -> Result<SendFuture<'_, T>, EncodeError>
92 where
93 M: Encode<T::SendBuffer>,
94 {
95 self.inner.connection.send_message(|buffer| {
96 encode_header::<T>(buffer, txid, ordinal)?;
97 buffer.encode_next(message)
98 })
99 }
100}
101
102impl<T: Transport> Clone for Client<T> {
103 fn clone(&self) -> Self {
104 Self { inner: self.inner.clone() }
105 }
106}
107
108pub struct TwoWayResponseFuture<'a, T: Transport> {
110 inner: &'a ClientInner<T>,
111 index: Option<u32>,
112}
113
114impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
115 fn drop(&mut self) {
116 if let Some(index) = self.index {
118 let mut responses = self.inner.responses.lock().unwrap();
119 if responses.get(index).unwrap().cancel() {
120 responses.free(index);
121 }
122 }
123 }
124}
125
126impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
127 type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
128
129 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130 let this = Pin::into_inner(self);
131 let Some(index) = this.index else {
132 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
133 };
134
135 let mut responses = this.inner.responses.lock().unwrap();
136 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
137 Ok(ready)
138 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
139 Err(termination_reason)
140 } else {
141 return Poll::Pending;
142 };
143
144 responses.free(index);
145 this.index = None;
146 Poll::Ready(ready)
147 }
148}
149
150#[pin_project(PinnedDrop)]
152pub struct TwoWayRequestFuture<'a, T: Transport> {
153 inner: &'a ClientInner<T>,
154 index: Option<u32>,
155 #[pin]
156 send_future: SendFuture<'a, T>,
157}
158
159#[pinned_drop]
160impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
161 fn drop(self: Pin<&mut Self>) {
162 if let Some(index) = self.index {
163 let mut responses = self.inner.responses.lock().unwrap();
164
165 responses.free(index);
168 }
169 }
170}
171
172impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
173 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
174
175 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176 let this = self.project();
177
178 let Some(index) = *this.index else {
179 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
180 };
181
182 let result = ready!(this.send_future.poll(cx));
183 *this.index = None;
184 if let Err(error) = result {
185 this.inner.responses.lock().unwrap().free(index);
187 Poll::Ready(Err(error))
188 } else {
189 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
190 }
191 }
192}
193
194pub trait ClientHandler<T: Transport> {
196 fn on_event(
203 &mut self,
204 ordinal: u64,
205 buffer: T::RecvBuffer,
206 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
207}
208
209pub struct ClientDispatcher<T: Transport> {
218 inner: Arc<ClientInner<T>>,
219 exclusive: T::Exclusive,
220 is_terminated: bool,
221}
222
223impl<T: Transport> Drop for ClientDispatcher<T> {
224 fn drop(&mut self) {
225 if !self.is_terminated {
226 unsafe {
228 self.terminate(ProtocolError::Stopped);
229 }
230 }
231 }
232}
233
234impl<T: Transport> ClientDispatcher<T> {
235 pub fn new(transport: T) -> Self {
237 let (shared, exclusive) = transport.split();
238 Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
239 }
240
241 unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
245 unsafe {
247 self.inner.connection.terminate(error);
248 }
249 self.inner.responses.lock().unwrap().wake_all();
250 }
251
252 pub fn client(&self) -> Client<T> {
256 Client { inner: self.inner.clone() }
257 }
258
259 pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
261 where
262 H: ClientHandler<T>,
263 {
264 let error = loop {
270 let result = unsafe { self.run_one(&mut handler).await };
272 if let Err(error) = result {
273 break error;
274 }
275 };
276
277 unsafe {
279 self.terminate(error.clone());
280 }
281 self.is_terminated = true;
282
283 match error {
284 ProtocolError::Stopped => Ok(handler),
287
288 _ => Err(error),
290 }
291 }
292
293 async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
297 where
298 H: ClientHandler<T>,
299 {
300 let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
302
303 let (txid, ordinal) =
304 decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
305
306 if ordinal == ORDINAL_EPITAPH {
307 let epitaph =
308 decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
309 return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
310 } else if txid == 0 {
311 handler.on_event(ordinal, buffer).await?;
312 } else {
313 let mut responses = self.inner.responses.lock().unwrap();
314 let locker = responses
315 .get(txid - 1)
316 .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
317
318 match locker.write(ordinal, buffer) {
319 Ok(false) => (),
321 Ok(true) => responses.free(txid - 1),
323 Err(LockerError::NotWriteable) => {
324 return Err(ProtocolError::UnrequestedResponse { txid });
325 }
326 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
327 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
328 }
329 }
330 }
331
332 Ok(())
333 }
334
335 pub async fn run_client(self) -> Result<(), ProtocolError<T::Error>> {
337 self.run(IgnoreEvents).await.map(|_| ())
338 }
339}
340
/// A [`ClientHandler`] that accepts every event and does nothing with it.
pub struct IgnoreEvents;
343
344impl<T: Transport> ClientHandler<T> for IgnoreEvents {
345 async fn on_event(&mut self, _: u64, _: T::RecvBuffer) -> Result<(), ProtocolError<T::Error>> {
346 Ok(())
347 }
348}