// fidl_next_protocol/client.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! FIDL protocol clients.

7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10use std::sync::{Arc, Mutex};
11
12use fidl_next_codec::{Encode, EncodeError, EncoderExt};
13
14use crate::lockers::Lockers;
15use crate::{decode_header, encode_header, ProtocolError, SendFuture, Transport, TransportExt};
16
17use super::lockers::LockerError;
19struct Shared<T: Transport> {
20    responses: Mutex<Lockers<T::RecvBuffer>>,
21}
22
23impl<T: Transport> Shared<T> {
24    fn new() -> Self {
25        Self { responses: Mutex::new(Lockers::new()) }
26    }
27}
28
29/// A sender for a client endpoint.
30pub struct ClientSender<T: Transport> {
31    shared: Arc<Shared<T>>,
32    sender: T::Sender,
33}
34
35impl<T: Transport> ClientSender<T> {
36    /// Closes the channel from the client end.
37    pub fn close(&self) {
38        T::close(&self.sender);
39    }
40
41    /// Send a request.
42    pub fn send_one_way<M>(
43        &self,
44        ordinal: u64,
45        request: &mut M,
46    ) -> Result<SendFuture<'_, T>, EncodeError>
47    where
48        M: Encode<T::SendBuffer>,
49    {
50        self.send_message(0, ordinal, request)
51    }
52
53    /// Send a request and await for a response.
54    pub fn send_two_way<M>(
55        &self,
56        ordinal: u64,
57        request: &mut M,
58    ) -> Result<ResponseFuture<'_, T>, EncodeError>
59    where
60        M: Encode<T::SendBuffer>,
61    {
62        let index = self.shared.responses.lock().unwrap().alloc(ordinal);
63
64        // Send with txid = index + 1 because indices start at 0.
65        match self.send_message(index + 1, ordinal, request) {
66            Ok(future) => Ok(ResponseFuture {
67                shared: &self.shared,
68                index,
69                state: ResponseFutureState::Sending(future),
70            }),
71            Err(e) => {
72                self.shared.responses.lock().unwrap().free(index);
73                Err(e)
74            }
75        }
76    }
77
78    fn send_message<M>(
79        &self,
80        txid: u32,
81        ordinal: u64,
82        message: &mut M,
83    ) -> Result<SendFuture<'_, T>, EncodeError>
84    where
85        M: Encode<T::SendBuffer>,
86    {
87        let mut buffer = T::acquire(&self.sender);
88        encode_header::<T>(&mut buffer, txid, ordinal)?;
89        buffer.encode_next(message)?;
90        Ok(T::send(&self.sender, buffer))
91    }
92}
93
94impl<T: Transport> Clone for ClientSender<T> {
95    fn clone(&self) -> Self {
96        Self { shared: self.shared.clone(), sender: self.sender.clone() }
97    }
98}
99
100enum ResponseFutureState<'a, T: Transport> {
101    Sending(SendFuture<'a, T>),
102    Receiving,
103    // We store the completion state locally so that we can free the locker during poll, instead of
104    // waiting until the future is dropped.
105    Completed,
106}
107
108/// A future for a request pending a response.
109pub struct ResponseFuture<'a, T: Transport> {
110    shared: &'a Shared<T>,
111    index: u32,
112    state: ResponseFutureState<'a, T>,
113}
114
115impl<T: Transport> Drop for ResponseFuture<'_, T> {
116    fn drop(&mut self) {
117        let mut responses = self.shared.responses.lock().unwrap();
118        match self.state {
119            // SAFETY: The future was canceled before it could be sent. The transaction ID was never
120            // used, so it's safe to immediately reuse.
121            ResponseFutureState::Sending(_) => responses.free(self.index),
122            ResponseFutureState::Receiving => {
123                if responses.get(self.index).unwrap().cancel() {
124                    responses.free(self.index);
125                }
126            }
127            // We already freed the slot when we completed.
128            ResponseFutureState::Completed => (),
129        }
130    }
131}
132
133impl<T: Transport> ResponseFuture<'_, T> {
134    fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
135        let mut responses = self.shared.responses.lock().unwrap();
136        if let Some(ready) = responses.get(self.index).unwrap().read(cx.waker()) {
137            responses.free(self.index);
138            self.state = ResponseFutureState::Completed;
139            Poll::Ready(Ok(ready))
140        } else {
141            Poll::Pending
142        }
143    }
144}
145
146impl<T: Transport> Future for ResponseFuture<'_, T> {
147    type Output = Result<T::RecvBuffer, T::Error>;
148
149    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150        // SAFETY: We treat the state as pinned as long as it is sending.
151        let this = unsafe { Pin::into_inner_unchecked(self) };
152
153        match &mut this.state {
154            ResponseFutureState::Sending(future) => {
155                // SAFETY: Because the state is sending, we always treat its future as pinned.
156                let pinned = unsafe { Pin::new_unchecked(future) };
157                match pinned.poll(cx) {
158                    // The send has not completed yet. Leave the state as sending.
159                    Poll::Pending => Poll::Pending,
160                    Poll::Ready(Ok(())) => {
161                        // The send completed successfully. Change the state to receiving and poll
162                        // for receiving.
163                        this.state = ResponseFutureState::Receiving;
164                        this.poll_receiving(cx)
165                    }
166                    Poll::Ready(Err(e)) => {
167                        // The send completed unsuccessfully. We can safely free the cell and set
168                        // our state to completed.
169
170                        this.shared.responses.lock().unwrap().free(this.index);
171                        this.state = ResponseFutureState::Completed;
172                        Poll::Ready(Err(e))
173                    }
174                }
175            }
176            ResponseFutureState::Receiving => this.poll_receiving(cx),
177            // We could reach here if this future is polled after completion, but that's not
178            // supposed to happen.
179            ResponseFutureState::Completed => unreachable!(),
180        }
181    }
182}
183
184/// A type which handles incoming events for a client.
185pub trait ClientHandler<T: Transport> {
186    /// Handles a received client event.
187    ///
188    /// The client cannot handle more messages until `on_event` completes. If `on_event` may block,
189    /// perform asynchronous work, or take a long time to process a message, it should offload work
190    /// to an async task.
191    fn on_event(&mut self, sender: &ClientSender<T>, ordinal: u64, buffer: T::RecvBuffer);
192}
193
194/// A client for an endpoint.
195///
196/// It must be actively polled to receive events and two-way message responses.
197pub struct Client<T: Transport> {
198    sender: ClientSender<T>,
199    receiver: T::Receiver,
200}
201
202impl<T: Transport> Client<T> {
203    /// Creates a new client from a transport.
204    pub fn new(transport: T) -> Self {
205        let (sender, receiver) = transport.split();
206        let shared = Arc::new(Shared::new());
207        Self { sender: ClientSender { shared, sender }, receiver }
208    }
209
210    /// Returns the sender for the client.
211    pub fn sender(&self) -> &ClientSender<T> {
212        &self.sender
213    }
214
215    /// Runs the client with the provided handler.
216    pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
217    where
218        H: ClientHandler<T>,
219    {
220        let result = self.run_to_completion(&mut handler).await;
221        self.sender.shared.responses.lock().unwrap().wake_all();
222
223        result
224    }
225
226    /// Runs the client with the [`IgnoreEvents`] handler.
227    pub async fn run_sender(&mut self) -> Result<(), ProtocolError<T::Error>> {
228        self.run(IgnoreEvents).await
229    }
230
231    async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
232    where
233        H: ClientHandler<T>,
234    {
235        while let Some(mut buffer) =
236            T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?
237        {
238            let (txid, ordinal) =
239                decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
240            if txid == 0 {
241                handler.on_event(&self.sender, ordinal, buffer);
242            } else {
243                let mut responses = self.sender.shared.responses.lock().unwrap();
244                let locker = responses
245                    .get(txid - 1)
246                    .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
247
248                match locker.write(ordinal, buffer) {
249                    // Reader didn't cancel
250                    Ok(false) => (),
251                    // Reader canceled, we can drop the entry
252                    Ok(true) => responses.free(txid - 1),
253                    Err(LockerError::NotWriteable) => {
254                        return Err(ProtocolError::UnrequestedResponse(txid));
255                    }
256                    Err(LockerError::MismatchedOrdinal { expected, actual }) => {
257                        return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
258                    }
259                }
260            }
261        }
262
263        Ok(())
264    }
265}
266
267/// A client handler which ignores any incoming events.
268pub struct IgnoreEvents;
269
270impl<T: Transport> ClientHandler<T> for IgnoreEvents {
271    fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
272}