fidl_next_bind/future/
two_way.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll, ready};
9
10use fidl_next_codec::{
11    Constrained, Decode, Decoded, DecoderExt, EncodeError, FromWire, IntoNatural, Wire,
12};
13use fidl_next_protocol::Transport;
14use pin_project::pin_project;
15
16use crate::{Error, TwoWayMethod};
17
18#[pin_project(project = TwoWayFutureStateProj, project_replace = TwoWayFutureStateOwn)]
19enum TwoWayFutureState<'a, T: Transport> {
20    EncodeError(EncodeError),
21    SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>),
22    SendingRequest(#[pin] fidl_next_protocol::TwoWayRequestFuture<'a, T>),
23    ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>),
24    ReceivingResponse(#[pin] fidl_next_protocol::TwoWayResponseFuture<'a, T>),
25    DecodeBuffer(T::RecvBuffer),
26    Finished,
27}
28
// Generates per-variant helpers for `TwoWayFutureState` and for its owned
// projection, `TwoWayFutureStateOwn`.
//
// Every `Variant(Payload) => check unwrap` entry expands to:
//   - `fn check(&self) -> bool` on the state, true when it is `Variant`;
//   - `fn unwrap(self) -> Payload` on the owned projection, which returns the
//     payload and hits `unreachable!()` for any other variant.
macro_rules! impl_two_way_future_state {
    ($(
        $name:ident($payload:ty) => $is:ident $take:ident
    ),* $(,)?) => {
        impl<T: Transport> TwoWayFutureState<'_, T> {
            $(
                #[allow(dead_code)]
                fn $is(&self) -> bool {
                    matches!(self, Self::$name(_))
                }
            )*
        }

        impl<'a, T: Transport> TwoWayFutureStateOwn<'a, T> {
            $(
                #[allow(dead_code)]
                fn $take(self) -> $payload {
                    match self {
                        Self::$name(inner) => inner,
                        _ => unreachable!(),
                    }
                }
            )*
        }
    };
}
55
56impl_two_way_future_state! {
57    EncodeError(EncodeError) => is_encode_error unwrap_encode_error,
58    SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>)
59        => is_send_request unwrap_send_request,
60    ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>)
61        => is_receive_response unwrap_receive_response,
62    DecodeBuffer(T::RecvBuffer) => is_decode_buffer unwrap_decode_buffer,
63}
64
65impl<'a, T: Transport> TwoWayFutureState<'a, T> {
66    fn finish(self: Pin<&mut Self>) -> TwoWayFutureStateOwn<'a, T> {
67        self.project_replace(Self::Finished)
68    }
69
70    fn poll_advance(
71        mut self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73    ) -> Poll<Result<(), Error<T::Error>>> {
74        Poll::Ready(match self.as_mut().project() {
75            TwoWayFutureStateProj::EncodeError(_) => {
76                Err(Error::Encode(self.finish().unwrap_encode_error()))
77            }
78            TwoWayFutureStateProj::SendRequest(_) => {
79                let future = self.as_mut().finish().unwrap_send_request();
80                self.project_replace(Self::SendingRequest(future));
81                Ok(())
82            }
83            TwoWayFutureStateProj::SendingRequest(future) => match ready!(future.poll(cx)) {
84                Ok(future) => {
85                    self.project_replace(Self::ReceiveResponse(future));
86                    Ok(())
87                }
88                Err(error) => {
89                    self.finish();
90                    Err(Error::Protocol(error))
91                }
92            },
93            TwoWayFutureStateProj::ReceiveResponse(_) => {
94                let future = self.as_mut().finish().unwrap_receive_response();
95                self.project_replace(Self::ReceivingResponse(future));
96                Ok(())
97            }
98            TwoWayFutureStateProj::ReceivingResponse(future) => match ready!(future.poll(cx)) {
99                Ok(buffer) => {
100                    self.project_replace(Self::DecodeBuffer(buffer));
101                    Ok(())
102                }
103                Err(error) => {
104                    self.finish();
105                    Err(Error::Protocol(error))
106                }
107            },
108            TwoWayFutureStateProj::DecodeBuffer(_) | TwoWayFutureStateProj::Finished => {
109                panic!("TwoWayFutureState polled after completing");
110            }
111        })
112    }
113
114    fn poll_until(
115        mut self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117        is_done: impl Fn(&Self) -> bool,
118    ) -> Poll<Result<TwoWayFutureStateOwn<'a, T>, Error<T::Error>>> {
119        while !is_done(&self) {
120            if let Err(error) = ready!(self.as_mut().poll_advance(cx)) {
121                return Poll::Ready(Err(error));
122            }
123        }
124        Poll::Ready(Ok(self.finish()))
125    }
126}
127
// Declares the public future types that drive a `TwoWayFutureState`.
//
// Each entry has the shape
//
//     #[attrs] Name -> Output where [bounds] { check => |state| expr }
//
// and expands to a pin-projected struct `Name<'a, M, T>` plus a `Future` impl
// that polls the state until `TwoWayFutureState::check` holds, binds the
// taken state to `state`, and resolves to `Ok(expr)`.
macro_rules! two_way_futures {
    ($(
        $(#[$attr:meta])* $name:ident -> $out:ty
        where [$($bounds:tt)*]
        {
            $done:ident => |$owned:ident| $finish:expr
        }
    ),* $(,)?) => {
        $(
            $(#[$attr])*
            #[must_use = "futures do nothing unless polled"]
            #[pin_project]
            pub struct $name<'a, M, T: Transport> {
                #[pin]
                state: TwoWayFutureState<'a, T>,
                _method: PhantomData<M>,
            }

            impl<'a, M, T> Future for $name<'a, M, T>
            where
                T: Transport,
                $($bounds)*
            {
                type Output = Result<$out, Error<T::Error>>;

                fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                    let this = self.project();
                    let reached = this.state.poll_until(cx, TwoWayFutureState::$done);
                    let $owned = ready!(reached)?;
                    Poll::Ready(Ok($finish))
                }
            }
        )*
    }
}
168
169two_way_futures! {
170    // `foo().await`
171
172    /// A future which performs a two-way FIDL method call.
173    TwoWayFuture -> <M::Response as IntoNatural>::Natural
174    where [
175        M: TwoWayMethod,
176        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
177        <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Owned<'de>>,
178    ]
179    {
180        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?.take()
181    },
182
183    // `foo().encode()?.await`
184
185    /// A future which performs a two-way FIDL method call.
186    ///
187    /// This future has already been successfully encoded. It still needs to be
188    /// sent and a response needs to be received.
189    EncodedTwoWayFuture -> <M::Response as IntoNatural>::Natural
190    where [
191        M: TwoWayMethod,
192        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
193        <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Owned<'de>>,
194    ]
195    {
196        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?.take()
197    },
198
199    // `foo().send().await`
200
201    /// A future which sends a two-way FIDL method call.
202    ///
203    /// This future returns another future which completes the FIDL call.
204    SendTwoWayFuture -> SentTwoWayFuture<'a, M, T>
205    where []
206    {
207        is_receive_response => |state| SentTwoWayFuture {
208            state: TwoWayFutureState::ReceiveResponse(state.unwrap_receive_response()),
209            _method: PhantomData,
210        }
211    },
212
213    // `foo().send().await?.await`
214
215    /// A future which performs a two-way FIDL method call.
216    ///
217    /// This future has already been successfully encoded and sent. A response
218    /// still needs to be received.
219    SentTwoWayFuture -> <M::Response as IntoNatural>::Natural
220    where [
221        M: TwoWayMethod,
222        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
223        <M::Response as IntoNatural>::Natural: for<'de> FromWire<<M::Response as Wire>::Owned<'de>>,
224    ]
225    {
226        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?.take()
227    },
228
229    // `foo().recv_buffer().await`
230
231    /// A future which receives a two-way FIDL method call as a `RecvBuffer`.
232    ///
233    /// This future returns the response buffer without decoding it first.
234    RecvBufferTwoWayFuture -> T::RecvBuffer
235    where []
236    {
237        is_decode_buffer => |state| state.unwrap_decode_buffer()
238    },
239
240    // `foo().wire().await`
241
242    /// A future which decodes a two-way FIDL method call as a wire type.
243    ///
244    /// This future returns the decoded response.
245    WireTwoWayFuture -> Decoded<M::Response, T::RecvBuffer>
246    where [
247        M: TwoWayMethod,
248        M::Response: Decode<T::RecvBuffer> + Constrained<Constraint = ()> + IntoNatural,
249    ]
250    {
251        is_decode_buffer => |state| state.unwrap_decode_buffer().decode::<M::Response>()?
252    }
253}
254
// Adds one inherent method to several future types at once.
//
// Usage: a space-separated list of future type names, a comma, then the
// method item. The item is emitted unchanged into a separate inherent `impl`
// block for each listed type.
macro_rules! impl_for_futures {
    (
        $($target:ident)*,
        $method:item
    ) => {
        $(
            impl<'a, M, T> $target<'a, M, T>
            where
                T: Transport,
            {
                $method
            }
        )*
    }
}
267
// Each of these methods chooses the stage at which the next `.await` stops
// processing the message. By default, processing runs all the way to the end
// of the pipeline and returns a natural type.
271
272impl_for_futures! {
273    TwoWayFuture,
274
275    /// Encodes the two-way message.
276    ///
277    /// Returns a future which completes the request, or an error if it failed.
278    pub fn encode(self) -> Result<EncodedTwoWayFuture<'a, M, T>, Error<T::Error>> {
279        Ok(EncodedTwoWayFuture {
280            state: match self.state {
281                TwoWayFutureState::EncodeError(error) => return Err(Error::Encode(error)),
282                state => state,
283            },
284            _method: PhantomData,
285        })
286    }
287}
288
289impl_for_futures! {
290    TwoWayFuture EncodedTwoWayFuture,
291
292    /// Sends the two-way message.
293    ///
294    /// Returns a future which completes the request, or an error if it failed.
295    pub fn send(self) -> SendTwoWayFuture<'a, M, T> {
296        SendTwoWayFuture {
297            state: self.state,
298            _method: PhantomData,
299        }
300    }
301}
302
303impl_for_futures! {
304    TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
305
306    /// Receives the response to the two-way message.
307    ///
308    /// Returns the response buffer, or an error if it failed.
309    pub fn recv_buffer(self) -> RecvBufferTwoWayFuture<'a, M, T> {
310        RecvBufferTwoWayFuture {
311            state: self.state,
312            _method: PhantomData,
313        }
314    }
315}
316
317impl_for_futures! {
318    TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
319
320    /// Receives the response to the two-way message and decodes it as a wire
321    /// type.
322    ///
323    /// Returns the decoded response, or an error if it failed.
324    pub fn wire(self) -> WireTwoWayFuture<'a, M, T> {
325        WireTwoWayFuture {
326            state: self.state,
327            _method: PhantomData,
328        }
329    }
330}
331
332impl<'a, M, T: Transport> TwoWayFuture<'a, M, T> {
333    /// Returns a `TwoWayFuture` wrapping the given result.
334    pub fn from_untyped(
335        result: Result<fidl_next_protocol::TwoWayRequestFuture<'a, T>, EncodeError>,
336    ) -> Self {
337        Self {
338            state: match result {
339                Ok(future) => TwoWayFutureState::SendRequest(future),
340                Err(error) => TwoWayFutureState::EncodeError(error),
341            },
342            _method: PhantomData,
343        }
344    }
345}