fidl_next_protocol/endpoints/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_next_codec::{Encode, EncodeError, EncoderExt};
12use pin_project::{pin_project, pinned_drop};
13
14use crate::concurrency::sync::{Arc, Mutex};
15use crate::endpoints::connection::{Connection, ORDINAL_EPITAPH};
16use crate::endpoints::lockers::{LockerError, Lockers};
17use crate::{ProtocolError, SendFuture, Transport, decode_epitaph, decode_header, encode_header};
18
19struct ClientInner<T: Transport> {
20    connection: Connection<T>,
21    responses: Mutex<Lockers<T::RecvBuffer>>,
22}
23
24impl<T: Transport> ClientInner<T> {
25    fn new(shared: T::Shared) -> Self {
26        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
27    }
28}
29
30/// A client endpoint.
31pub struct Client<T: Transport> {
32    inner: Arc<ClientInner<T>>,
33}
34
35impl<T: Transport> Drop for Client<T> {
36    fn drop(&mut self) {
37        if Arc::strong_count(&self.inner) == 2 {
38            // This was the last reference to the connection other than the one
39            // in the dispatcher itself. Stop the connection.
40            self.close();
41        }
42    }
43}
44
45impl<T: Transport> Client<T> {
46    /// Closes the channel from the client end.
47    pub fn close(&self) {
48        self.inner.connection.stop();
49    }
50
51    /// Send a request.
52    pub fn send_one_way<M>(
53        &self,
54        ordinal: u64,
55        request: M,
56    ) -> Result<SendFuture<'_, T>, EncodeError>
57    where
58        M: Encode<T::SendBuffer>,
59    {
60        self.send_message(0, ordinal, request)
61    }
62
63    /// Send a request and await for a response.
64    pub fn send_two_way<M>(
65        &self,
66        ordinal: u64,
67        request: M,
68    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
69    where
70        M: Encode<T::SendBuffer>,
71    {
72        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
73
74        // Send with txid = index + 1 because indices start at 0.
75        match self.send_message(index + 1, ordinal, request) {
76            Ok(send_future) => {
77                Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
78            }
79            Err(e) => {
80                self.inner.responses.lock().unwrap().free(index);
81                Err(e)
82            }
83        }
84    }
85
86    fn send_message<M>(
87        &self,
88        txid: u32,
89        ordinal: u64,
90        message: M,
91    ) -> Result<SendFuture<'_, T>, EncodeError>
92    where
93        M: Encode<T::SendBuffer>,
94    {
95        self.inner.connection.send_message(|buffer| {
96            encode_header::<T>(buffer, txid, ordinal)?;
97            buffer.encode_next(message)
98        })
99    }
100}
101
102impl<T: Transport> Clone for Client<T> {
103    fn clone(&self) -> Self {
104        Self { inner: self.inner.clone() }
105    }
106}
107
108/// A future for a pending response to a two-way message.
109pub struct TwoWayResponseFuture<'a, T: Transport> {
110    inner: &'a ClientInner<T>,
111    index: Option<u32>,
112}
113
114impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
115    fn drop(&mut self) {
116        // If `index` is `Some`, then we still need to free our locker.
117        if let Some(index) = self.index {
118            let mut responses = self.inner.responses.lock().unwrap();
119            if responses.get(index).unwrap().cancel() {
120                responses.free(index);
121            }
122        }
123    }
124}
125
126impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
127    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
128
129    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130        let this = Pin::into_inner(self);
131        let Some(index) = this.index else {
132            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
133        };
134
135        let mut responses = this.inner.responses.lock().unwrap();
136        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
137            Ok(ready)
138        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
139            Err(termination_reason)
140        } else {
141            return Poll::Pending;
142        };
143
144        responses.free(index);
145        this.index = None;
146        Poll::Ready(ready)
147    }
148}
149
150/// A future for a sending a two-way FIDL message.
151#[pin_project(PinnedDrop)]
152pub struct TwoWayRequestFuture<'a, T: Transport> {
153    inner: &'a ClientInner<T>,
154    index: Option<u32>,
155    #[pin]
156    send_future: SendFuture<'a, T>,
157}
158
159#[pinned_drop]
160impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
161    fn drop(self: Pin<&mut Self>) {
162        if let Some(index) = self.index {
163            let mut responses = self.inner.responses.lock().unwrap();
164
165            // The future was canceled before it could be sent. The transaction
166            // ID was never used, so it's safe to immediately reuse.
167            responses.free(index);
168        }
169    }
170}
171
172impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
173    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
174
175    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176        let this = self.project();
177
178        let Some(index) = *this.index else {
179            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
180        };
181
182        let result = ready!(this.send_future.poll(cx));
183        *this.index = None;
184        if let Err(error) = result {
185            // The send failed. Free the locker and return an error.
186            this.inner.responses.lock().unwrap().free(index);
187            Poll::Ready(Err(error))
188        } else {
189            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
190        }
191    }
192}
193
194/// A type which handles incoming events for a client.
195pub trait ClientHandler<T: Transport> {
196    /// Handles a received client event, returning the appropriate flow control
197    /// to perform.
198    ///
199    /// The client cannot handle more messages until `on_event` completes. If
200    /// `on_event` should handle requests in parallel, it should spawn a new
201    /// async task and return.
202    fn on_event(
203        &mut self,
204        ordinal: u64,
205        buffer: T::RecvBuffer,
206    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
207}
208
209/// A dispatcher for a client endpoint.
210///
211/// A client dispatcher receives all of the incoming messages and dispatches them to the client
212/// handler and two-way futures. It acts as the message pump for the client.
213///
214/// The dispatcher must be actively polled to receive events and two-way message responses. If the
215/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
216/// and two-way message futures will not receive their responses.
217pub struct ClientDispatcher<T: Transport> {
218    inner: Arc<ClientInner<T>>,
219    exclusive: T::Exclusive,
220    is_terminated: bool,
221}
222
223impl<T: Transport> Drop for ClientDispatcher<T> {
224    fn drop(&mut self) {
225        if !self.is_terminated {
226            // SAFETY: We checked that the connection has not been terminated.
227            unsafe {
228                self.terminate(ProtocolError::Stopped);
229            }
230        }
231    }
232}
233
234impl<T: Transport> ClientDispatcher<T> {
235    /// Creates a new client from a transport.
236    pub fn new(transport: T) -> Self {
237        let (shared, exclusive) = transport.split();
238        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
239    }
240
241    /// # Safety
242    ///
243    /// The connection must not yet be terminated.
244    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
245        // SAFETY: We checked that the connection has not been terminated.
246        unsafe {
247            self.inner.connection.terminate(error);
248        }
249        self.inner.responses.lock().unwrap().wake_all();
250    }
251
252    /// Returns a client for the dispatcher.
253    ///
254    /// When the last `Client` is dropped, the dispatcher will be stopped.
255    pub fn client(&self) -> Client<T> {
256        Client { inner: self.inner.clone() }
257    }
258
259    /// Runs the client with the provided handler.
260    pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
261    where
262        H: ClientHandler<T>,
263    {
264        // We may assume that the connection has not been terminated because
265        // connections are only terminated by `run` and `drop`. Neither of those
266        // could have been called before this method because `run` consumes
267        // `self` and `drop` is only ever called once.
268
269        let error = loop {
270            // SAFETY: The connection has not been terminated.
271            let result = unsafe { self.run_one(&mut handler).await };
272            if let Err(error) = result {
273                break error;
274            }
275        };
276
277        // SAFETY: The connection has not been terminated.
278        unsafe {
279            self.terminate(error.clone());
280        }
281        self.is_terminated = true;
282
283        match error {
284            // We consider clients to have finished successfully only if they
285            // stop themselves manually.
286            ProtocolError::Stopped => Ok(handler),
287
288            // Otherwise, the client finished with an error.
289            _ => Err(error),
290        }
291    }
292
293    /// # Safety
294    ///
295    /// The connection must not be terminated.
296    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
297    where
298        H: ClientHandler<T>,
299    {
300        // SAFETY: The caller guaranteed that the connection is not terminated.
301        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
302
303        let (txid, ordinal) =
304            decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
305
306        if ordinal == ORDINAL_EPITAPH {
307            let epitaph =
308                decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
309            return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
310        } else if txid == 0 {
311            handler.on_event(ordinal, buffer).await?;
312        } else {
313            let mut responses = self.inner.responses.lock().unwrap();
314            let locker = responses
315                .get(txid - 1)
316                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
317
318            match locker.write(ordinal, buffer) {
319                // Reader didn't cancel
320                Ok(false) => (),
321                // Reader canceled, we can drop the entry
322                Ok(true) => responses.free(txid - 1),
323                Err(LockerError::NotWriteable) => {
324                    return Err(ProtocolError::UnrequestedResponse { txid });
325                }
326                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
327                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
328                }
329            }
330        }
331
332        Ok(())
333    }
334
335    /// Runs the client with the [`IgnoreEvents`] handler.
336    pub async fn run_client(self) -> Result<(), ProtocolError<T::Error>> {
337        self.run(IgnoreEvents).await.map(|_| ())
338    }
339}
340
341/// A client handler which ignores any incoming events.
342pub struct IgnoreEvents;
343
344impl<T: Transport> ClientHandler<T> for IgnoreEvents {
345    async fn on_event(&mut self, _: u64, _: T::RecvBuffer) -> Result<(), ProtocolError<T::Error>> {
346        Ok(())
347    }
348}