fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_constants::EPITAPH_ORDINAL;
12use fidl_next_codec::{Constrained, Encode, EncodeError, EncoderExt, Wire};
13use pin_project::{pin_project, pinned_drop};
14
15use crate::concurrency::sync::{Arc, Mutex};
16use crate::endpoints::connection::Connection;
17use crate::endpoints::lockers::{LockerError, Lockers};
18use crate::{
19 Flexibility, ProtocolError, SendFuture, Transport, decode_epitaph, decode_header, encode_header,
20};
21
22struct ClientInner<T: Transport> {
23 connection: Connection<T>,
24 responses: Mutex<Lockers<T::RecvBuffer>>,
25}
26
27impl<T: Transport> ClientInner<T> {
28 fn new(shared: T::Shared) -> Self {
29 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
30 }
31}
32
33pub struct Client<T: Transport> {
35 inner: Arc<ClientInner<T>>,
36}
37
38impl<T: Transport> Drop for Client<T> {
39 fn drop(&mut self) {
40 if Arc::strong_count(&self.inner) == 2 {
41 self.close();
44 }
45 }
46}
47
48impl<T: Transport> Client<T> {
49 pub fn close(&self) {
51 self.inner.connection.stop();
52 }
53
54 pub fn send_one_way<W>(
56 &self,
57 ordinal: u64,
58 flexibility: Flexibility,
59 request: impl Encode<W, T::SendBuffer>,
60 ) -> Result<SendFuture<'_, T>, EncodeError>
61 where
62 W: Constrained<Constraint = ()> + Wire,
63 {
64 self.send_message(0, ordinal, flexibility, request)
65 }
66
67 pub fn send_two_way<W>(
69 &self,
70 ordinal: u64,
71 flexibility: Flexibility,
72 request: impl Encode<W, T::SendBuffer>,
73 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
74 where
75 W: Constrained<Constraint = ()> + Wire,
76 {
77 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
78
79 match self.send_message(index + 1, ordinal, flexibility, request) {
81 Ok(send_future) => {
82 Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
83 }
84 Err(e) => {
85 self.inner.responses.lock().unwrap().free(index);
86 Err(e)
87 }
88 }
89 }
90
91 fn send_message<W>(
92 &self,
93 txid: u32,
94 ordinal: u64,
95 flexibility: Flexibility,
96 message: impl Encode<W, T::SendBuffer>,
97 ) -> Result<SendFuture<'_, T>, EncodeError>
98 where
99 W: Constrained<Constraint = ()> + Wire,
100 {
101 self.inner.connection.send_message(|buffer| {
102 encode_header::<T>(buffer, txid, ordinal, flexibility)?;
103 buffer.encode_next(message, ())
104 })
105 }
106}
107
108impl<T: Transport> Clone for Client<T> {
109 fn clone(&self) -> Self {
110 Self { inner: self.inner.clone() }
111 }
112}
113
114pub struct TwoWayResponseFuture<'a, T: Transport> {
116 inner: &'a ClientInner<T>,
117 index: Option<u32>,
118}
119
120impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
121 fn drop(&mut self) {
122 if let Some(index) = self.index {
124 let mut responses = self.inner.responses.lock().unwrap();
125 if responses.get(index).unwrap().cancel() {
126 responses.free(index);
127 }
128 }
129 }
130}
131
132impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
133 type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
134
135 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 let this = Pin::into_inner(self);
137 let Some(index) = this.index else {
138 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
139 };
140
141 let mut responses = this.inner.responses.lock().unwrap();
142 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
143 Ok(ready)
144 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
145 Err(termination_reason)
146 } else {
147 return Poll::Pending;
148 };
149
150 responses.free(index);
151 this.index = None;
152 Poll::Ready(ready)
153 }
154}
155
156#[pin_project(PinnedDrop)]
158pub struct TwoWayRequestFuture<'a, T: Transport> {
159 inner: &'a ClientInner<T>,
160 index: Option<u32>,
161 #[pin]
162 send_future: SendFuture<'a, T>,
163}
164
165#[pinned_drop]
166impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
167 fn drop(self: Pin<&mut Self>) {
168 if let Some(index) = self.index {
169 let mut responses = self.inner.responses.lock().unwrap();
170
171 responses.free(index);
174 }
175 }
176}
177
178impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
179 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
180
181 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
182 let this = self.project();
183
184 let Some(index) = *this.index else {
185 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
186 };
187
188 let result = ready!(this.send_future.poll(cx));
189 *this.index = None;
190 if let Err(error) = result {
191 this.inner.responses.lock().unwrap().free(index);
193 Poll::Ready(Err(error))
194 } else {
195 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
196 }
197 }
198}
199
200pub trait ClientHandler<T: Transport> {
202 fn on_event(
209 &mut self,
210 ordinal: u64,
211 flexibility: Flexibility,
212 buffer: T::RecvBuffer,
213 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
214}
215
216pub struct ClientDispatcher<T: Transport> {
225 inner: Arc<ClientInner<T>>,
226 exclusive: T::Exclusive,
227 is_terminated: bool,
228}
229
230impl<T: Transport> Drop for ClientDispatcher<T> {
231 fn drop(&mut self) {
232 if !self.is_terminated {
233 unsafe {
235 self.terminate(ProtocolError::Stopped);
236 }
237 }
238 }
239}
240
241impl<T: Transport> ClientDispatcher<T> {
242 pub fn new(transport: T) -> Self {
244 let (shared, exclusive) = transport.split();
245 Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
246 }
247
248 unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
252 unsafe {
254 self.inner.connection.terminate(error);
255 }
256 self.inner.responses.lock().unwrap().wake_all();
257 }
258
259 pub fn client(&self) -> Client<T> {
263 Client { inner: self.inner.clone() }
264 }
265
266 pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
268 where
269 H: ClientHandler<T>,
270 {
271 let error = loop {
277 let result = unsafe { self.run_one(&mut handler).await };
279 if let Err(error) = result {
280 break error;
281 }
282 };
283
284 unsafe {
286 self.terminate(error.clone());
287 }
288 self.is_terminated = true;
289
290 match error {
291 ProtocolError::Stopped => Ok(handler),
294
295 _ => Err(error),
297 }
298 }
299
300 async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
304 where
305 H: ClientHandler<T>,
306 {
307 let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
309
310 let (txid, ordinal, flexibility) =
311 decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
312
313 if ordinal == EPITAPH_ORDINAL {
314 let epitaph =
315 decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
316 return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
317 } else if txid == 0 {
318 handler.on_event(ordinal, flexibility, buffer).await?;
319 } else {
320 let mut responses = self.inner.responses.lock().unwrap();
321 let locker = responses
322 .get(txid - 1)
323 .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
324
325 match locker.write(ordinal, buffer) {
326 Ok(false) => (),
328 Ok(true) => responses.free(txid - 1),
330 Err(LockerError::NotWriteable) => {
331 return Err(ProtocolError::UnrequestedResponse { txid });
332 }
333 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
334 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
335 }
336 }
337 }
338
339 Ok(())
340 }
341}