// fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_constants::EPITAPH_ORDINAL;
12use fidl_next_codec::{AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire};
13use pin_project::{pin_project, pinned_drop};
14
15use crate::concurrency::sync::{Arc, Mutex};
16use crate::endpoints::connection::Connection;
17use crate::endpoints::lockers::{LockerError, Lockers};
18use crate::wire::{Epitaph, MessageHeader};
19use crate::{Body, Flexibility, ProtocolError, SendFuture, Transport};
20
21struct ClientInner<T: Transport> {
22 connection: Connection<T>,
23 responses: Mutex<Lockers<Body<T>>>,
24}
25
26impl<T: Transport> ClientInner<T> {
27 fn new(shared: T::Shared) -> Self {
28 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
29 }
30}
31
32pub struct Client<T: Transport> {
34 inner: Arc<ClientInner<T>>,
35}
36
37impl<T: Transport> Drop for Client<T> {
38 fn drop(&mut self) {
39 if Arc::strong_count(&self.inner) == 2 {
40 self.close();
43 }
44 }
45}
46
47impl<T: Transport> Client<T> {
48 pub fn close(&self) {
50 self.inner.connection.stop();
51 }
52
53 pub fn send_one_way<W>(
55 &self,
56 ordinal: u64,
57 flexibility: Flexibility,
58 request: impl Encode<W, T::SendBuffer>,
59 ) -> Result<SendFuture<'_, T>, EncodeError>
60 where
61 W: Wire<Constraint = ()>,
62 {
63 self.send_message(0, ordinal, flexibility, request)
64 }
65
66 pub fn send_two_way<W>(
68 &self,
69 ordinal: u64,
70 flexibility: Flexibility,
71 request: impl Encode<W, T::SendBuffer>,
72 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
73 where
74 W: Wire<Constraint = ()>,
75 {
76 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
77
78 match self.send_message(index + 1, ordinal, flexibility, request) {
80 Ok(send_future) => {
81 Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
82 }
83 Err(e) => {
84 self.inner.responses.lock().unwrap().free(index);
85 Err(e)
86 }
87 }
88 }
89
90 fn send_message<W>(
91 &self,
92 txid: u32,
93 ordinal: u64,
94 flexibility: Flexibility,
95 message: impl Encode<W, T::SendBuffer>,
96 ) -> Result<SendFuture<'_, T>, EncodeError>
97 where
98 W: Wire<Constraint = ()>,
99 {
100 self.inner.connection.send_message(|buffer| {
101 buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
102 buffer.encode_next(message)
103 })
104 }
105}
106
107impl<T: Transport> Clone for Client<T> {
108 fn clone(&self) -> Self {
109 Self { inner: self.inner.clone() }
110 }
111}
112
113pub struct TwoWayResponseFuture<'a, T: Transport> {
115 inner: &'a ClientInner<T>,
116 index: Option<u32>,
117}
118
119impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
120 fn drop(&mut self) {
121 if let Some(index) = self.index {
123 let mut responses = self.inner.responses.lock().unwrap();
124 if responses.get(index).unwrap().cancel() {
125 responses.free(index);
126 }
127 }
128 }
129}
130
131impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
132 type Output = Result<Body<T>, ProtocolError<T::Error>>;
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 let this = Pin::into_inner(self);
136 let Some(index) = this.index else {
137 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
138 };
139
140 let mut responses = this.inner.responses.lock().unwrap();
141 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
142 Ok(ready)
143 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
144 Err(termination_reason)
145 } else {
146 return Poll::Pending;
147 };
148
149 responses.free(index);
150 this.index = None;
151 Poll::Ready(ready)
152 }
153}
154
155#[pin_project(PinnedDrop)]
157pub struct TwoWayRequestFuture<'a, T: Transport> {
158 inner: &'a ClientInner<T>,
159 index: Option<u32>,
160 #[pin]
161 send_future: SendFuture<'a, T>,
162}
163
164#[pinned_drop]
165impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
166 fn drop(self: Pin<&mut Self>) {
167 if let Some(index) = self.index {
168 let mut responses = self.inner.responses.lock().unwrap();
169
170 responses.free(index);
173 }
174 }
175}
176
177impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
178 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
179
180 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181 let this = self.project();
182
183 let Some(index) = *this.index else {
184 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
185 };
186
187 let result = ready!(this.send_future.poll(cx));
188 *this.index = None;
189 if let Err(error) = result {
190 this.inner.responses.lock().unwrap().free(index);
192 Poll::Ready(Err(error))
193 } else {
194 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
195 }
196 }
197}
198
199pub trait LocalClientHandler<T: Transport> {
204 fn on_event(
208 &mut self,
209 ordinal: u64,
210 flexibility: Flexibility,
211 body: Body<T>,
212 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>>;
213}
214
215pub trait ClientHandler<T: Transport>: Send {
217 fn on_event(
223 &mut self,
224 ordinal: u64,
225 flexibility: Flexibility,
226 body: Body<T>,
227 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
228}
229
/// Wraps a [`ClientHandler`] so it can be used wherever a [`LocalClientHandler`] is expected.
#[repr(transparent)]
pub struct ClientHandlerToLocalAdapter<H>(H);
233
234impl<T, H> LocalClientHandler<T> for ClientHandlerToLocalAdapter<H>
235where
236 T: Transport,
237 H: ClientHandler<T>,
238{
239 #[inline]
240 fn on_event(
241 &mut self,
242 ordinal: u64,
243 flexibility: Flexibility,
244 body: Body<T>,
245 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> {
246 self.0.on_event(ordinal, flexibility, body)
247 }
248}
249
250pub struct ClientDispatcher<T: Transport> {
259 inner: Arc<ClientInner<T>>,
260 exclusive: T::Exclusive,
261 is_terminated: bool,
262}
263
264impl<T: Transport> Drop for ClientDispatcher<T> {
265 fn drop(&mut self) {
266 if !self.is_terminated {
267 unsafe {
269 self.terminate(ProtocolError::Stopped);
270 }
271 }
272 }
273}
274
275impl<T: Transport> ClientDispatcher<T> {
276 pub fn new(transport: T) -> Self {
278 let (shared, exclusive) = transport.split();
279 Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
280 }
281
282 unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
286 unsafe {
288 self.inner.connection.terminate(error);
289 }
290 self.inner.responses.lock().unwrap().wake_all();
291 }
292
293 pub fn client(&self) -> Client<T> {
297 Client { inner: self.inner.clone() }
298 }
299
300 pub async fn run<H>(self, handler: H) -> Result<H, ProtocolError<T::Error>>
302 where
303 H: ClientHandler<T>,
304 {
305 self.run_local(ClientHandlerToLocalAdapter(handler)).await.map(|adapter| adapter.0)
308 }
309
310 pub async fn run_local<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
312 where
313 H: LocalClientHandler<T>,
314 {
315 let error = loop {
321 let result = unsafe { self.run_one(&mut handler).await };
323 if let Err(error) = result {
324 break error;
325 }
326 };
327
328 unsafe {
330 self.terminate(error.clone());
331 }
332 self.is_terminated = true;
333
334 match error {
335 ProtocolError::Stopped => Ok(handler),
338
339 _ => Err(error),
341 }
342 }
343
344 async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
348 where
349 H: LocalClientHandler<T>,
350 {
351 let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
353
354 let header = {
366 let mut decoder = buffer.as_decoder();
367
368 let header = decoder
369 .decode_prefix::<MessageHeader>()
370 .map_err(ProtocolError::InvalidMessageHeader)?;
371
372 if header.ordinal == EPITAPH_ORDINAL {
376 let epitaph =
377 decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
378 return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
379 }
380
381 header
382 };
383
384 if header.txid == 0 {
385 handler.on_event(*header.ordinal, header.flexibility(), Body::new(buffer)).await?;
386 } else {
387 let mut responses = self.inner.responses.lock().unwrap();
388 let locker = responses
389 .get(*header.txid - 1)
390 .ok_or_else(|| ProtocolError::UnrequestedResponse { txid: *header.txid })?;
391
392 match locker.write(*header.ordinal, Body::new(buffer)) {
393 Ok(false) => (),
395 Ok(true) => responses.free(*header.txid - 1),
397 Err(LockerError::NotWriteable) => {
398 return Err(ProtocolError::UnrequestedResponse { txid: *header.txid });
399 }
400 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
401 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
402 }
403 }
404 }
405
406 Ok(())
407 }
408}