fidl_next_protocol/endpoints/client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_constants::EPITAPH_ORDINAL;
12use fidl_next_codec::{Constrained, Encode, EncodeError, EncoderExt, Wire};
13use pin_project::{pin_project, pinned_drop};
14
15use crate::concurrency::sync::{Arc, Mutex};
16use crate::endpoints::connection::Connection;
17use crate::endpoints::lockers::{LockerError, Lockers};
18use crate::{
19    Flexibility, ProtocolError, SendFuture, Transport, decode_epitaph, decode_header, encode_header,
20};
21
22struct ClientInner<T: Transport> {
23    connection: Connection<T>,
24    responses: Mutex<Lockers<T::RecvBuffer>>,
25}
26
27impl<T: Transport> ClientInner<T> {
28    fn new(shared: T::Shared) -> Self {
29        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
30    }
31}
32
33/// A client endpoint.
34pub struct Client<T: Transport> {
35    inner: Arc<ClientInner<T>>,
36}
37
38impl<T: Transport> Drop for Client<T> {
39    fn drop(&mut self) {
40        if Arc::strong_count(&self.inner) == 2 {
41            // This was the last reference to the connection other than the one
42            // in the dispatcher itself. Stop the connection.
43            self.close();
44        }
45    }
46}
47
48impl<T: Transport> Client<T> {
49    /// Closes the channel from the client end.
50    pub fn close(&self) {
51        self.inner.connection.stop();
52    }
53
54    /// Send a request.
55    pub fn send_one_way<W>(
56        &self,
57        ordinal: u64,
58        flexibility: Flexibility,
59        request: impl Encode<W, T::SendBuffer>,
60    ) -> Result<SendFuture<'_, T>, EncodeError>
61    where
62        W: Constrained<Constraint = ()> + Wire,
63    {
64        self.send_message(0, ordinal, flexibility, request)
65    }
66
67    /// Send a request and await for a response.
68    pub fn send_two_way<W>(
69        &self,
70        ordinal: u64,
71        flexibility: Flexibility,
72        request: impl Encode<W, T::SendBuffer>,
73    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
74    where
75        W: Constrained<Constraint = ()> + Wire,
76    {
77        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
78
79        // Send with txid = index + 1 because indices start at 0.
80        match self.send_message(index + 1, ordinal, flexibility, request) {
81            Ok(send_future) => {
82                Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
83            }
84            Err(e) => {
85                self.inner.responses.lock().unwrap().free(index);
86                Err(e)
87            }
88        }
89    }
90
91    fn send_message<W>(
92        &self,
93        txid: u32,
94        ordinal: u64,
95        flexibility: Flexibility,
96        message: impl Encode<W, T::SendBuffer>,
97    ) -> Result<SendFuture<'_, T>, EncodeError>
98    where
99        W: Constrained<Constraint = ()> + Wire,
100    {
101        self.inner.connection.send_message(|buffer| {
102            encode_header::<T>(buffer, txid, ordinal, flexibility)?;
103            buffer.encode_next(message, ())
104        })
105    }
106}
107
108impl<T: Transport> Clone for Client<T> {
109    fn clone(&self) -> Self {
110        Self { inner: self.inner.clone() }
111    }
112}
113
114/// A future for a pending response to a two-way message.
115pub struct TwoWayResponseFuture<'a, T: Transport> {
116    inner: &'a ClientInner<T>,
117    index: Option<u32>,
118}
119
120impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
121    fn drop(&mut self) {
122        // If `index` is `Some`, then we still need to free our locker.
123        if let Some(index) = self.index {
124            let mut responses = self.inner.responses.lock().unwrap();
125            if responses.get(index).unwrap().cancel() {
126                responses.free(index);
127            }
128        }
129    }
130}
131
132impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
133    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
134
135    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136        let this = Pin::into_inner(self);
137        let Some(index) = this.index else {
138            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
139        };
140
141        let mut responses = this.inner.responses.lock().unwrap();
142        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
143            Ok(ready)
144        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
145            Err(termination_reason)
146        } else {
147            return Poll::Pending;
148        };
149
150        responses.free(index);
151        this.index = None;
152        Poll::Ready(ready)
153    }
154}
155
156/// A future for a sending a two-way FIDL message.
157#[pin_project(PinnedDrop)]
158pub struct TwoWayRequestFuture<'a, T: Transport> {
159    inner: &'a ClientInner<T>,
160    index: Option<u32>,
161    #[pin]
162    send_future: SendFuture<'a, T>,
163}
164
165#[pinned_drop]
166impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
167    fn drop(self: Pin<&mut Self>) {
168        if let Some(index) = self.index {
169            let mut responses = self.inner.responses.lock().unwrap();
170
171            // The future was canceled before it could be sent. The transaction
172            // ID was never used, so it's safe to immediately reuse.
173            responses.free(index);
174        }
175    }
176}
177
178impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
179    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
180
181    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
182        let this = self.project();
183
184        let Some(index) = *this.index else {
185            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
186        };
187
188        let result = ready!(this.send_future.poll(cx));
189        *this.index = None;
190        if let Err(error) = result {
191            // The send failed. Free the locker and return an error.
192            this.inner.responses.lock().unwrap().free(index);
193            Poll::Ready(Err(error))
194        } else {
195            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
196        }
197    }
198}
199
200/// A type which handles incoming events for a client.
201pub trait ClientHandler<T: Transport> {
202    /// Handles a received client event, returning the appropriate flow control
203    /// to perform.
204    ///
205    /// The client cannot handle more messages until `on_event` completes. If
206    /// `on_event` should handle requests in parallel, it should spawn a new
207    /// async task and return.
208    fn on_event(
209        &mut self,
210        ordinal: u64,
211        flexibility: Flexibility,
212        buffer: T::RecvBuffer,
213    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
214}
215
216/// A dispatcher for a client endpoint.
217///
218/// A client dispatcher receives all of the incoming messages and dispatches them to the client
219/// handler and two-way futures. It acts as the message pump for the client.
220///
221/// The dispatcher must be actively polled to receive events and two-way message responses. If the
222/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
223/// and two-way message futures will not receive their responses.
224pub struct ClientDispatcher<T: Transport> {
225    inner: Arc<ClientInner<T>>,
226    exclusive: T::Exclusive,
227    is_terminated: bool,
228}
229
230impl<T: Transport> Drop for ClientDispatcher<T> {
231    fn drop(&mut self) {
232        if !self.is_terminated {
233            // SAFETY: We checked that the connection has not been terminated.
234            unsafe {
235                self.terminate(ProtocolError::Stopped);
236            }
237        }
238    }
239}
240
241impl<T: Transport> ClientDispatcher<T> {
242    /// Creates a new client from a transport.
243    pub fn new(transport: T) -> Self {
244        let (shared, exclusive) = transport.split();
245        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
246    }
247
248    /// # Safety
249    ///
250    /// The connection must not yet be terminated.
251    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
252        // SAFETY: We checked that the connection has not been terminated.
253        unsafe {
254            self.inner.connection.terminate(error);
255        }
256        self.inner.responses.lock().unwrap().wake_all();
257    }
258
259    /// Returns a client for the dispatcher.
260    ///
261    /// When the last `Client` is dropped, the dispatcher will be stopped.
262    pub fn client(&self) -> Client<T> {
263        Client { inner: self.inner.clone() }
264    }
265
266    /// Runs the client with the provided handler.
267    pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
268    where
269        H: ClientHandler<T>,
270    {
271        // We may assume that the connection has not been terminated because
272        // connections are only terminated by `run` and `drop`. Neither of those
273        // could have been called before this method because `run` consumes
274        // `self` and `drop` is only ever called once.
275
276        let error = loop {
277            // SAFETY: The connection has not been terminated.
278            let result = unsafe { self.run_one(&mut handler).await };
279            if let Err(error) = result {
280                break error;
281            }
282        };
283
284        // SAFETY: The connection has not been terminated.
285        unsafe {
286            self.terminate(error.clone());
287        }
288        self.is_terminated = true;
289
290        match error {
291            // We consider clients to have finished successfully only if they
292            // stop themselves manually.
293            ProtocolError::Stopped => Ok(handler),
294
295            // Otherwise, the client finished with an error.
296            _ => Err(error),
297        }
298    }
299
300    /// # Safety
301    ///
302    /// The connection must not be terminated.
303    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
304    where
305        H: ClientHandler<T>,
306    {
307        // SAFETY: The caller guaranteed that the connection is not terminated.
308        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
309
310        let (txid, ordinal, flexibility) =
311            decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
312
313        if ordinal == EPITAPH_ORDINAL {
314            let epitaph =
315                decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
316            return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
317        } else if txid == 0 {
318            handler.on_event(ordinal, flexibility, buffer).await?;
319        } else {
320            let mut responses = self.inner.responses.lock().unwrap();
321            let locker = responses
322                .get(txid - 1)
323                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
324
325            match locker.write(ordinal, buffer) {
326                // Reader didn't cancel
327                Ok(false) => (),
328                // Reader canceled, we can drop the entry
329                Ok(true) => responses.free(txid - 1),
330                Err(LockerError::NotWriteable) => {
331                    return Err(ProtocolError::UnrequestedResponse { txid });
332                }
333                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
334                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
335                }
336            }
337        }
338
339        Ok(())
340    }
341}