1use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll, ready};
9
10use fidl_next_codec::{
11 AsDecoder, AsDecoderExt as _, Constrained, Decode, Decoded, EncodeError, FromWire, IntoNatural,
12 Wire,
13};
14use fidl_next_protocol::{Body, Transport};
15use pin_project::pin_project;
16
17use crate::{Error, TwoWayMethod};
18
/// Internal state machine shared by every two-way future type in this module.
///
/// States advance in a fixed order:
/// - `EncodeError` goes straight to `Finished`, and polling reports `Error::Encode`.
/// - `SendRequest` moves to `SendingRequest`, which moves to `ReceiveResponse` once the request
///   future resolves.
/// - `ReceiveResponse` moves to `ReceivingResponse`, which moves to `DecodeBuffer` once the
///   response future resolves.
///
/// A protocol error in either in-flight state moves the machine to `Finished`.
#[pin_project(project = TwoWayFutureStateProj, project_replace = TwoWayFutureStateOwn)]
enum TwoWayFutureState<'a, T: Transport> {
    /// Encoding the request failed; the error is reported on the next advance.
    EncodeError(EncodeError),
    /// The request future, stored unpinned so it can be moved into `SendingRequest`.
    SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>),
    /// The request future, pinned in place and being polled.
    SendingRequest(#[pin] fidl_next_protocol::TwoWayRequestFuture<'a, T>),
    /// The response future, stored unpinned so it can be moved into `ReceivingResponse`.
    ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>),
    /// The response future, pinned in place and being polled.
    ReceivingResponse(#[pin] fidl_next_protocol::TwoWayResponseFuture<'a, T>),
    /// The received response body, waiting to be taken out by a wrapping future.
    DecodeBuffer(Body<T>),
    /// Terminal state; left behind whenever the state has been taken by value.
    Finished,
}
29
/// Generates two helpers for each listed variant of `TwoWayFutureState` that has an unpinned
/// payload:
///
/// - `$check(&self) -> bool` on `TwoWayFutureState`, which reports whether the state is
///   currently that variant. It is used as the `is_done` predicate given to `poll_until`.
/// - `$unwrap(self) -> $ty` on the owned replacement `TwoWayFutureStateOwn`, which extracts the
///   variant's payload.
///
/// # Panics
///
/// A generated `$unwrap` panics, naming the expected variant, if the owned state is a
/// different variant. Callers only invoke it after a matching check, so that would be a bug.
macro_rules! impl_two_way_future_state {
    ($(
        $variant:ident($ty:ty) => $check:ident $unwrap:ident
    ),* $(,)?) => {
        impl<T: Transport> TwoWayFutureState<'_, T> {
            $(
                // Not every generated predicate is needed by the futures in this file.
                #[allow(dead_code)]
                fn $check(&self) -> bool {
                    matches!(self, Self::$variant(_))
                }
            )*
        }

        impl<'a, T: Transport> TwoWayFutureStateOwn<'a, T> {
            $(
                // Generated uniformly for every listed variant; some may go unused.
                #[allow(dead_code)]
                fn $unwrap(self) -> $ty {
                    let Self::$variant(value) = self else {
                        // Name the expected state so an invariant violation is diagnosable.
                        unreachable!("expected TwoWayFutureState::{}", stringify!($variant))
                    };
                    value
                }
            )*
        }
    };
}
56
57impl_two_way_future_state! {
58 EncodeError(EncodeError) => is_encode_error unwrap_encode_error,
59 SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>)
60 => is_send_request unwrap_send_request,
61 ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>)
62 => is_receive_response unwrap_receive_response,
63 DecodeBuffer(Body<T>) => is_decode_buffer unwrap_decode_buffer,
64}
65
66impl<'a, T: Transport> TwoWayFutureState<'a, T> {
67 fn finish(self: Pin<&mut Self>) -> TwoWayFutureStateOwn<'a, T> {
68 self.project_replace(Self::Finished)
69 }
70
71 fn poll_advance(
72 mut self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 ) -> Poll<Result<(), Error<T::Error>>> {
75 Poll::Ready(match self.as_mut().project() {
76 TwoWayFutureStateProj::EncodeError(_) => {
77 Err(Error::Encode(self.finish().unwrap_encode_error()))
78 }
79 TwoWayFutureStateProj::SendRequest(_) => {
80 let future = self.as_mut().finish().unwrap_send_request();
81 self.project_replace(Self::SendingRequest(future));
82 Ok(())
83 }
84 TwoWayFutureStateProj::SendingRequest(future) => match ready!(future.poll(cx)) {
85 Ok(future) => {
86 self.project_replace(Self::ReceiveResponse(future));
87 Ok(())
88 }
89 Err(error) => {
90 self.finish();
91 Err(Error::Protocol(error))
92 }
93 },
94 TwoWayFutureStateProj::ReceiveResponse(_) => {
95 let future = self.as_mut().finish().unwrap_receive_response();
96 self.project_replace(Self::ReceivingResponse(future));
97 Ok(())
98 }
99 TwoWayFutureStateProj::ReceivingResponse(future) => match ready!(future.poll(cx)) {
100 Ok(body) => {
101 self.project_replace(Self::DecodeBuffer(body));
102 Ok(())
103 }
104 Err(error) => {
105 self.finish();
106 Err(Error::Protocol(error))
107 }
108 },
109 TwoWayFutureStateProj::DecodeBuffer(_) | TwoWayFutureStateProj::Finished => {
110 panic!("TwoWayFutureState polled after completing");
111 }
112 })
113 }
114
115 fn poll_until(
116 mut self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 is_done: impl Fn(&Self) -> bool,
119 ) -> Poll<Result<TwoWayFutureStateOwn<'a, T>, Error<T::Error>>> {
120 while !is_done(&self) {
121 if let Err(error) = ready!(self.as_mut().poll_advance(cx)) {
122 return Poll::Ready(Err(error));
123 }
124 }
125 Poll::Ready(Ok(self.finish()))
126 }
127}
128
/// Defines the public two-way future types.
///
/// Each entry has the form
/// `$(#[attr])* Name -> Output where [bounds] { is_done => |binding| body }`.
///
/// For each entry the macro generates:
/// - a pin-projected struct wrapping a `TwoWayFutureState`;
/// - a `Future` impl that advances the state until `is_done` holds, binds the owned state to
///   `binding`, and resolves to `Ok(body)`.
///
/// Errors from the state machine, and any `?` inside `body`, resolve to `Err`.
macro_rules! two_way_futures {
    ($(
        $(#[$attr:meta])* $name:ident -> $out:ty
        where [$($bound:tt)*]
        {
            $is_done:ident => |$binding:ident| $body:expr
        }
    ),* $(,)?) => {
        $(
            $(#[$attr])*
            #[must_use = "futures do nothing unless polled"]
            #[pin_project]
            pub struct $name<'a, M, T: Transport> {
                #[pin]
                state: TwoWayFutureState<'a, T>,
                _method: PhantomData<M>,
            }

            impl<'a, M, T> Future for $name<'a, M, T>
            where
                T: Transport,
                $($bound)*
            {
                type Output = Result<$out, Error<T::Error>>;

                fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                    let projected = self.project();
                    match projected.state.poll_until(cx, TwoWayFutureState::$is_done) {
                        Poll::Pending => Poll::Pending,
                        Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
                        Poll::Ready(Ok($binding)) => Poll::Ready(Ok($body)),
                    }
                }
            }
        )*
    };
}
169
170two_way_futures! {
171 TwoWayFuture -> <M::Response as IntoNatural>::Natural
175 where [
176 M: TwoWayMethod,
177 M::Response: Wire<Constraint = ()> + IntoNatural,
178 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
179 <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Narrowed<'de>>,
180 ]
181 {
182 is_decode_buffer => |state| state.unwrap_decode_buffer().into_decoded::<M::Response>()?.take()
183 },
184
185 EncodedTwoWayFuture -> <M::Response as IntoNatural>::Natural
192 where [
193 M: TwoWayMethod,
194 M::Response: Wire<Constraint = ()> + IntoNatural,
195 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
196 <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Narrowed<'de>>,
197 ]
198 {
199 is_decode_buffer => |state| state.unwrap_decode_buffer().into_decoded::<M::Response>()?.take()
200 },
201
202 SendTwoWayFuture -> SentTwoWayFuture<'a, M, T>
208 where []
209 {
210 is_receive_response => |state| SentTwoWayFuture {
211 state: TwoWayFutureState::ReceiveResponse(state.unwrap_receive_response()),
212 _method: PhantomData,
213 }
214 },
215
216 SentTwoWayFuture -> <M::Response as IntoNatural>::Natural
223 where [
224 M: TwoWayMethod,
225 M::Response: Wire<Constraint = ()> + IntoNatural,
226 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
227 <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Narrowed<'de>>,
228 ]
229 {
230 is_decode_buffer => |state| state.unwrap_decode_buffer().into_decoded::<M::Response>()?.take()
231 },
232
233 RecvBufferTwoWayFuture -> Body<T>
239 where []
240 {
241 is_decode_buffer => |state| state.unwrap_decode_buffer()
242 },
243
244 WireTwoWayFuture -> Decoded<M::Response, Body<T>>
250 where [
251 M: TwoWayMethod,
252 M::Response: Wire<Constraint = ()> + IntoNatural,
253 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
254 <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Narrowed<'de>>,
255 ]
256 {
257 is_decode_buffer => |state| state.unwrap_decode_buffer().into_decoded()?
258 }
259}
260
/// Pastes one method item into an inherent `impl` block for each listed future type.
///
/// Every target must have the generic shape `Name<'a, M, T>` with `T: Transport`. Inside the
/// pasted item, `'a`, `M`, and `T` refer to those parameters.
macro_rules! impl_for_futures {
    ($($target:ident)*, $method:item) => {
        $(
            impl<'a, M, T: Transport> $target<'a, M, T> {
                $method
            }
        )*
    };
}
273
274impl_for_futures! {
279 TwoWayFuture,
280
281 pub fn encode(self) -> Result<EncodedTwoWayFuture<'a, M, T>, Error<T::Error>> {
285 Ok(EncodedTwoWayFuture {
286 state: match self.state {
287 TwoWayFutureState::EncodeError(error) => return Err(Error::Encode(error)),
288 state => state,
289 },
290 _method: PhantomData,
291 })
292 }
293}
294
295impl_for_futures! {
296 TwoWayFuture EncodedTwoWayFuture,
297
298 pub fn send(self) -> SendTwoWayFuture<'a, M, T> {
302 SendTwoWayFuture {
303 state: self.state,
304 _method: PhantomData,
305 }
306 }
307}
308
309impl_for_futures! {
310 TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
311
312 pub fn recv_buffer(self) -> RecvBufferTwoWayFuture<'a, M, T> {
316 RecvBufferTwoWayFuture {
317 state: self.state,
318 _method: PhantomData,
319 }
320 }
321}
322
323impl_for_futures! {
324 TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
325
326 pub fn wire(self) -> WireTwoWayFuture<'a, M, T> {
331 WireTwoWayFuture {
332 state: self.state,
333 _method: PhantomData,
334 }
335 }
336}
337
338impl<'a, M, T: Transport> TwoWayFuture<'a, M, T> {
339 pub fn from_untyped(
341 result: Result<fidl_next_protocol::TwoWayRequestFuture<'a, T>, EncodeError>,
342 ) -> Self {
343 Self {
344 state: match result {
345 Ok(future) => TwoWayFutureState::SendRequest(future),
346 Err(error) => TwoWayFutureState::EncodeError(error),
347 },
348 _method: PhantomData,
349 }
350 }
351}