1use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll, ready};
9
10use fidl_next_codec::{
11 Constrained, Decode, Decoded, DecoderExt, EncodeError, FromWire, IntoNatural, Wire,
12};
13use fidl_next_protocol::Transport;
14use pin_project::pin_project;
15
16use crate::{Error, TwoWayMethod};
17
/// State machine shared by every two-way future in this module.
///
/// Each in-flight inner future appears in two forms:
/// - an unpinned variant (`SendRequest`, `ReceiveResponse`) holding a future that has
///   not been polled yet, so it can still be moved out by value (for example when a
///   `SendTwoWayFuture` hands the response future to a `SentTwoWayFuture`);
/// - a `#[pin]` variant (`SendingRequest`, `ReceivingResponse`) used while that future
///   is actually being polled.
#[pin_project(project = TwoWayFutureStateProj, project_replace = TwoWayFutureStateOwn)]
enum TwoWayFutureState<'a, T: Transport> {
    /// Encoding the request failed; the error is reported on the next advance.
    EncodeError(EncodeError),
    /// The request future exists but has not been polled yet.
    SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>),
    /// The request future is being polled until it yields a response future.
    SendingRequest(#[pin] fidl_next_protocol::TwoWayRequestFuture<'a, T>),
    /// The request was sent; the response future has not been polled yet.
    ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>),
    /// The response future is being polled until it yields the receive buffer.
    ReceivingResponse(#[pin] fidl_next_protocol::TwoWayResponseFuture<'a, T>),
    /// The response buffer has been received and is waiting to be taken.
    DecodeBuffer(T::RecvBuffer),
    /// The state's payload has been taken (or an error was reported); nothing is
    /// left to advance.
    Finished,
}
28
// Generates small accessors for the listed `TwoWayFutureState` variants:
//
// * `$is_variant(&self) -> bool` on `TwoWayFutureState`, reporting whether the
//   state currently is `$variant`;
// * `$into_payload(self) -> $payload` on the owned `TwoWayFutureStateOwn`
//   returned by `project_replace`, extracting the variant's payload. Calling it on
//   any other variant is a logic error and hits `unreachable!()`.
macro_rules! impl_two_way_future_state {
    ($(
        $variant:ident($payload:ty) => $is_variant:ident $into_payload:ident
    ),* $(,)?) => {
        impl<T: Transport> TwoWayFutureState<'_, T> {
            $(
                // Generated for every listed variant, whether or not anything calls it.
                #[allow(dead_code)]
                fn $is_variant(&self) -> bool {
                    matches!(self, Self::$variant(_))
                }
            )*
        }

        impl<'a, T: Transport> TwoWayFutureStateOwn<'a, T> {
            $(
                // Generated for every listed variant, whether or not anything calls it.
                #[allow(dead_code)]
                fn $into_payload(self) -> $payload {
                    match self {
                        Self::$variant(payload) => payload,
                        _ => unreachable!(),
                    }
                }
            )*
        }
    };
}
55
// Accessors for the unpinned variants only. The pinned `SendingRequest` and
// `ReceivingResponse` payloads cannot be moved out through `project_replace`, so
// they get no `unwrap_*` helper. `Finished` has no payload.
impl_two_way_future_state! {
    EncodeError(EncodeError) => is_encode_error unwrap_encode_error,
    SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>)
        => is_send_request unwrap_send_request,
    ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>)
        => is_receive_response unwrap_receive_response,
    DecodeBuffer(T::RecvBuffer) => is_decode_buffer unwrap_decode_buffer,
}
64
impl<'a, T: Transport> TwoWayFutureState<'a, T> {
    /// Takes the current state out by value and leaves `Finished` in its place.
    ///
    /// Payloads of the pinned variants are not handed back by value (pin-project's
    /// `project_replace` cannot move pinned fields out). This method is therefore used
    /// either to extract an unpinned variant's payload or to discard the current state.
    fn finish(self: Pin<&mut Self>) -> TwoWayFutureStateOwn<'a, T> {
        self.project_replace(Self::Finished)
    }

    /// Advances the state machine by at most one transition.
    ///
    /// Returns:
    /// - `Poll::Ready(Ok(()))` after a successful transition;
    /// - `Poll::Ready(Err(_))` if the request failed to encode (`Error::Encode`) or
    ///   sending/receiving failed (`Error::Protocol`); the state is then `Finished`;
    /// - `Poll::Pending` if the in-flight request or response future is not ready.
    ///
    /// # Panics
    ///
    /// Panics if called while in the `DecodeBuffer` or `Finished` state.
    fn poll_advance(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Error<T::Error>>> {
        Poll::Ready(match self.as_mut().project() {
            TwoWayFutureStateProj::EncodeError(_) => {
                // Move the stored error out; the state becomes `Finished`.
                Err(Error::Encode(self.finish().unwrap_encode_error()))
            }
            TwoWayFutureStateProj::SendRequest(_) => {
                // Move the unpolled request future into the pinned variant. It is first
                // polled on the next call (e.g. the next iteration of `poll_until`).
                let future = self.as_mut().finish().unwrap_send_request();
                self.project_replace(Self::SendingRequest(future));
                Ok(())
            }
            // `ready!` returns `Poll::Pending` from this function until the send completes.
            TwoWayFutureStateProj::SendingRequest(future) => match ready!(future.poll(cx)) {
                Ok(future) => {
                    // Request sent: park the not-yet-polled response future.
                    self.project_replace(Self::ReceiveResponse(future));
                    Ok(())
                }
                Err(error) => {
                    // Drop the failed future and mark the state `Finished`.
                    self.finish();
                    Err(Error::Protocol(error))
                }
            },
            TwoWayFutureStateProj::ReceiveResponse(_) => {
                // Move the unpolled response future into the pinned variant. It is first
                // polled on the next call.
                let future = self.as_mut().finish().unwrap_receive_response();
                self.project_replace(Self::ReceivingResponse(future));
                Ok(())
            }
            // `ready!` returns `Poll::Pending` from this function until the response arrives.
            TwoWayFutureStateProj::ReceivingResponse(future) => match ready!(future.poll(cx)) {
                Ok(buffer) => {
                    // Response received: hold the raw buffer until a caller takes it.
                    self.project_replace(Self::DecodeBuffer(buffer));
                    Ok(())
                }
                Err(error) => {
                    // Drop the failed future and mark the state `Finished`.
                    self.finish();
                    Err(Error::Protocol(error))
                }
            },
            // Nothing left to advance: the buffer is waiting to be taken, or the
            // future has already completed.
            TwoWayFutureStateProj::DecodeBuffer(_) | TwoWayFutureStateProj::Finished => {
                panic!("TwoWayFutureState polled after completing");
            }
        })
    }

    /// Advances the state machine until `is_done` accepts the current state, then takes
    /// that state out by value (leaving `Finished` behind).
    ///
    /// `is_done` is checked before every advance, so an already-accepted state is
    /// returned without polling anything. Returns `Poll::Pending` while an inner
    /// future is pending, and stops early with the first error from `poll_advance`.
    fn poll_until(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        is_done: impl Fn(&Self) -> bool,
    ) -> Poll<Result<TwoWayFutureStateOwn<'a, T>, Error<T::Error>>> {
        while !is_done(&self) {
            if let Err(error) = ready!(self.as_mut().poll_advance(cx)) {
                return Poll::Ready(Err(error));
            }
        }
        Poll::Ready(Ok(self.finish()))
    }
}
127
// Declares one public future type per entry. Each type wraps a `TwoWayFutureState`
// and implements `Future` with the entry's extra `where` bounds. When polled, it
// advances the shared state machine until `TwoWayFutureState::$target` accepts the
// state. It then binds the taken owned state to `$done` and evaluates `$convert` to
// produce the output. `?` inside `$convert` propagates errors out of `poll`.
macro_rules! two_way_futures {
    ($(
        $(#[$attr:meta])* $name:ident -> $output:ty
        where [$($bound:tt)*]
        {
            $target:ident => |$done:ident| $convert:expr
        }
    ),* $(,)?) => {
        $(
            $(#[$attr])*
            #[must_use = "futures do nothing unless polled"]
            #[pin_project]
            pub struct $name<'a, M, T: Transport> {
                #[pin]
                state: TwoWayFutureState<'a, T>,
                _method: PhantomData<M>,
            }

            impl<'a, M, T> Future for $name<'a, M, T>
            where
                T: Transport,
                $($bound)*
            {
                type Output = Result<$output, Error<T::Error>>;

                fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                    let this = self.project();
                    let reached = ready!(this.state.poll_until(cx, TwoWayFutureState::$target));
                    let $done = reached?;
                    Poll::Ready(Ok($convert))
                }
            }
        )*
    };
}
168
two_way_futures! {
    /// A future that sends a two-way request, waits for its response, and decodes the
    /// response into its natural type.
    ///
    /// Resolves to `Err(Error::Encode(..))` if the request could not be encoded, and to
    /// `Err(Error::Protocol(..))` if sending the request or receiving the response
    /// failed. Decoding failures are propagated with `?`.
    TwoWayFuture -> <M::Response as IntoNatural>::Natural
    where [
        M: TwoWayMethod,
        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
        <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Owned<'de>>,
    ]
    {
        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?.take()
    },

    /// A [`TwoWayFuture`] that has passed the encoding check performed by
    /// [`TwoWayFuture::encode`].
    ///
    /// It sends the request, waits for the response, and decodes it into its natural
    /// type. Sending, receiving, and decoding failures are reported as for
    /// [`TwoWayFuture`].
    EncodedTwoWayFuture -> <M::Response as IntoNatural>::Natural
    where [
        M: TwoWayMethod,
        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
        <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Owned<'de>>,
    ]
    {
        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?.take()
    },

    /// A future that resolves as soon as the request has been sent.
    ///
    /// It yields a [`SentTwoWayFuture`] that waits for the response. It resolves to
    /// `Err(Error::Encode(..))` or `Err(Error::Protocol(..))` if encoding or sending
    /// the request failed.
    SendTwoWayFuture -> SentTwoWayFuture<'a, M, T>
    where []
    {
        is_receive_response => |state| SentTwoWayFuture {
            state: TwoWayFutureState::ReceiveResponse(state.unwrap_receive_response()),
            _method: PhantomData,
        }
    },

    /// A future that waits for the response to an already-sent two-way request and
    /// decodes it into its natural type.
    ///
    /// Resolves to `Err(Error::Protocol(..))` if receiving the response failed.
    /// Decoding failures are propagated with `?`.
    SentTwoWayFuture -> <M::Response as IntoNatural>::Natural
    where [
        M: TwoWayMethod,
        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
        <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Owned<'de>>,
    ]
    {
        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?.take()
    },

    /// A future that resolves to the raw, undecoded response buffer.
    RecvBufferTwoWayFuture -> T::RecvBuffer
    where []
    {
        is_decode_buffer => |state| state.unwrap_decode_buffer()
    },

    /// A future that resolves to the response decoded in its wire form, without
    /// converting it to its natural type.
    // NOTE(review): the `IntoNatural` bound below looks unnecessary for producing a
    // `Decoded` value; confirm against `DecoderExt::decode` before relaxing it.
    WireTwoWayFuture -> Decoded<M::Response, T::RecvBuffer>
    where [
        M: TwoWayMethod,
        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
    ]
    {
        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?
    }
}
254
// Pastes the same inherent method (`$method`) into an `impl` block for each listed
// future type. All of those types share the `<'a, M, T>` parameter list, and the
// pasted method may refer to those parameter names.
macro_rules! impl_for_futures {
    (
        $($target:ident)*,
        $method:item
    ) => {
        $(
            impl<'a, M, T: Transport> $target<'a, M, T> {
                $method
            }
        )*
    };
}
267
268impl_for_futures! {
273 TwoWayFuture,
274
275 pub fn encode(self) -> Result<EncodedTwoWayFuture<'a, M, T>, Error<T::Error>> {
279 Ok(EncodedTwoWayFuture {
280 state: match self.state {
281 TwoWayFutureState::EncodeError(error) => return Err(Error::Encode(error)),
282 state => state,
283 },
284 _method: PhantomData,
285 })
286 }
287}
288
289impl_for_futures! {
290 TwoWayFuture EncodedTwoWayFuture,
291
292 pub fn send(self) -> SendTwoWayFuture<'a, M, T> {
296 SendTwoWayFuture {
297 state: self.state,
298 _method: PhantomData,
299 }
300 }
301}
302
303impl_for_futures! {
304 TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
305
306 pub fn recv_buffer(self) -> RecvBufferTwoWayFuture<'a, M, T> {
310 RecvBufferTwoWayFuture {
311 state: self.state,
312 _method: PhantomData,
313 }
314 }
315}
316
317impl_for_futures! {
318 TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
319
320 pub fn wire(self) -> WireTwoWayFuture<'a, M, T> {
325 WireTwoWayFuture {
326 state: self.state,
327 _method: PhantomData,
328 }
329 }
330}
331
332impl<'a, M, T: Transport> TwoWayFuture<'a, M, T> {
333 pub fn from_untyped(
335 result: Result<fidl_next_protocol::TwoWayRequestFuture<'a, T>, EncodeError>,
336 ) -> Self {
337 Self {
338 state: match result {
339 Ok(future) => TwoWayFutureState::SendRequest(future),
340 Err(error) => TwoWayFutureState::EncodeError(error),
341 },
342 _method: PhantomData,
343 }
344 }
345}