fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_constants::EPITAPH_ORDINAL;
12use fidl_next_codec::{AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire};
13use pin_project::{pin_project, pinned_drop};
14
15use crate::concurrency::sync::{Arc, Mutex};
16use crate::endpoints::connection::Connection;
17use crate::endpoints::lockers::{LockerError, Lockers};
18use crate::wire::{Epitaph, MessageHeader};
19use crate::{Body, Flexibility, ProtocolError, SendFuture, Transport};
20
21struct ClientInner<T: Transport> {
22 connection: Connection<T>,
23 responses: Mutex<Lockers<Body<T>>>,
24}
25
26impl<T: Transport> ClientInner<T> {
27 fn new(shared: T::Shared) -> Self {
28 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
29 }
30}
31
32pub struct Client<T: Transport> {
34 inner: Arc<ClientInner<T>>,
35}
36
37impl<T: Transport> Drop for Client<T> {
38 fn drop(&mut self) {
39 if Arc::strong_count(&self.inner) == 2 {
40 self.close();
43 }
44 }
45}
46
47impl<T: Transport> Client<T> {
48 pub fn close(&self) {
50 self.inner.connection.stop();
51 }
52
53 pub fn send_one_way<W>(
55 &self,
56 ordinal: u64,
57 flexibility: Flexibility,
58 request: impl Encode<W, T::SendBuffer>,
59 ) -> Result<SendFuture<'_, T>, EncodeError>
60 where
61 W: Wire<Constraint = ()>,
62 {
63 self.send_message(0, ordinal, flexibility, request)
64 }
65
66 pub fn send_two_way<W>(
68 &self,
69 ordinal: u64,
70 flexibility: Flexibility,
71 request: impl Encode<W, T::SendBuffer>,
72 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
73 where
74 W: Wire<Constraint = ()>,
75 {
76 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
77
78 match self.send_message(index + 1, ordinal, flexibility, request) {
80 Ok(send_future) => {
81 Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
82 }
83 Err(e) => {
84 self.inner.responses.lock().unwrap().free(index);
85 Err(e)
86 }
87 }
88 }
89
90 fn send_message<W>(
91 &self,
92 txid: u32,
93 ordinal: u64,
94 flexibility: Flexibility,
95 message: impl Encode<W, T::SendBuffer>,
96 ) -> Result<SendFuture<'_, T>, EncodeError>
97 where
98 W: Wire<Constraint = ()>,
99 {
100 self.inner.connection.send_message(|buffer| {
101 buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
102 buffer.encode_next(message)
103 })
104 }
105}
106
107impl<T: Transport> Clone for Client<T> {
108 fn clone(&self) -> Self {
109 Self { inner: self.inner.clone() }
110 }
111}
112
113pub struct TwoWayResponseFuture<'a, T: Transport> {
115 inner: &'a ClientInner<T>,
116 index: Option<u32>,
117}
118
119impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
120 fn drop(&mut self) {
121 if let Some(index) = self.index {
123 let mut responses = self.inner.responses.lock().unwrap();
124 if responses.get(index).unwrap().cancel() {
125 responses.free(index);
126 }
127 }
128 }
129}
130
131impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
132 type Output = Result<Body<T>, ProtocolError<T::Error>>;
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 let this = Pin::into_inner(self);
136 let Some(index) = this.index else {
137 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
138 };
139
140 let mut responses = this.inner.responses.lock().unwrap();
141 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
142 Ok(ready)
143 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
144 Err(termination_reason)
145 } else {
146 return Poll::Pending;
147 };
148
149 responses.free(index);
150 this.index = None;
151 Poll::Ready(ready)
152 }
153}
154
155#[pin_project(PinnedDrop)]
157pub struct TwoWayRequestFuture<'a, T: Transport> {
158 inner: &'a ClientInner<T>,
159 index: Option<u32>,
160 #[pin]
161 send_future: SendFuture<'a, T>,
162}
163
164#[pinned_drop]
165impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
166 fn drop(self: Pin<&mut Self>) {
167 if let Some(index) = self.index {
168 let mut responses = self.inner.responses.lock().unwrap();
169
170 responses.free(index);
173 }
174 }
175}
176
177impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
178 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
179
180 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181 let this = self.project();
182
183 let Some(index) = *this.index else {
184 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
185 };
186
187 let result = ready!(this.send_future.poll(cx));
188 *this.index = None;
189 if let Err(error) = result {
190 this.inner.responses.lock().unwrap().free(index);
192 Poll::Ready(Err(error))
193 } else {
194 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
195 }
196 }
197}
198
199pub trait ClientHandler<T: Transport> {
201 fn on_event(
208 &mut self,
209 ordinal: u64,
210 flexibility: Flexibility,
211 body: Body<T>,
212 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
213}
214
215pub struct ClientDispatcher<T: Transport> {
224 inner: Arc<ClientInner<T>>,
225 exclusive: T::Exclusive,
226 is_terminated: bool,
227}
228
229impl<T: Transport> Drop for ClientDispatcher<T> {
230 fn drop(&mut self) {
231 if !self.is_terminated {
232 unsafe {
234 self.terminate(ProtocolError::Stopped);
235 }
236 }
237 }
238}
239
240impl<T: Transport> ClientDispatcher<T> {
241 pub fn new(transport: T) -> Self {
243 let (shared, exclusive) = transport.split();
244 Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
245 }
246
247 unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
251 unsafe {
253 self.inner.connection.terminate(error);
254 }
255 self.inner.responses.lock().unwrap().wake_all();
256 }
257
258 pub fn client(&self) -> Client<T> {
262 Client { inner: self.inner.clone() }
263 }
264
265 pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
267 where
268 H: ClientHandler<T>,
269 {
270 let error = loop {
276 let result = unsafe { self.run_one(&mut handler).await };
278 if let Err(error) = result {
279 break error;
280 }
281 };
282
283 unsafe {
285 self.terminate(error.clone());
286 }
287 self.is_terminated = true;
288
289 match error {
290 ProtocolError::Stopped => Ok(handler),
293
294 _ => Err(error),
296 }
297 }
298
299 async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
303 where
304 H: ClientHandler<T>,
305 {
306 let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
308
309 let header = {
321 let mut decoder = buffer.as_decoder();
322
323 let header = decoder
324 .decode_prefix::<MessageHeader>()
325 .map_err(ProtocolError::InvalidMessageHeader)?;
326
327 if header.ordinal == EPITAPH_ORDINAL {
331 let epitaph =
332 decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
333 return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
334 }
335
336 header
337 };
338
339 if header.txid == 0 {
340 handler.on_event(*header.ordinal, header.flexibility(), Body::new(buffer)).await?;
341 } else {
342 let mut responses = self.inner.responses.lock().unwrap();
343 let locker = responses
344 .get(*header.txid - 1)
345 .ok_or_else(|| ProtocolError::UnrequestedResponse { txid: *header.txid })?;
346
347 match locker.write(*header.ordinal, Body::new(buffer)) {
348 Ok(false) => (),
350 Ok(true) => responses.free(*header.txid - 1),
352 Err(LockerError::NotWriteable) => {
353 return Err(ProtocolError::UnrequestedResponse { txid: *header.txid });
354 }
355 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
356 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
357 }
358 }
359 }
360
361 Ok(())
362 }
363}