Skip to main content

fidl_next_protocol/endpoints/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_constants::EPITAPH_ORDINAL;
12use fidl_next_codec::{AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire};
13use pin_project::{pin_project, pinned_drop};
14
15use crate::concurrency::sync::{Arc, Mutex};
16use crate::endpoints::connection::Connection;
17use crate::endpoints::lockers::{LockerError, Lockers};
18use crate::wire::{Epitaph, MessageHeader};
19use crate::{Body, Flexibility, ProtocolError, SendFuture, Transport};
20
21struct ClientInner<T: Transport> {
22    connection: Connection<T>,
23    responses: Mutex<Lockers<Body<T>>>,
24}
25
26impl<T: Transport> ClientInner<T> {
27    fn new(shared: T::Shared) -> Self {
28        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
29    }
30}
31
32/// A client endpoint.
33pub struct Client<T: Transport> {
34    inner: Arc<ClientInner<T>>,
35}
36
37impl<T: Transport> Drop for Client<T> {
38    fn drop(&mut self) {
39        if Arc::strong_count(&self.inner) == 2 {
40            // This was the last reference to the connection other than the one
41            // in the dispatcher itself. Stop the connection.
42            self.close();
43        }
44    }
45}
46
47impl<T: Transport> Client<T> {
48    /// Closes the channel from the client end.
49    pub fn close(&self) {
50        self.inner.connection.stop();
51    }
52
53    /// Send a request.
54    pub fn send_one_way<W>(
55        &self,
56        ordinal: u64,
57        flexibility: Flexibility,
58        request: impl Encode<W, T::SendBuffer>,
59    ) -> Result<SendFuture<'_, T>, EncodeError>
60    where
61        W: Wire<Constraint = ()>,
62    {
63        self.send_message(0, ordinal, flexibility, request)
64    }
65
66    /// Send a request and await for a response.
67    pub fn send_two_way<W>(
68        &self,
69        ordinal: u64,
70        flexibility: Flexibility,
71        request: impl Encode<W, T::SendBuffer>,
72    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
73    where
74        W: Wire<Constraint = ()>,
75    {
76        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
77
78        // Send with txid = index + 1 because indices start at 0.
79        match self.send_message(index + 1, ordinal, flexibility, request) {
80            Ok(send_future) => {
81                Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
82            }
83            Err(e) => {
84                self.inner.responses.lock().unwrap().free(index);
85                Err(e)
86            }
87        }
88    }
89
90    fn send_message<W>(
91        &self,
92        txid: u32,
93        ordinal: u64,
94        flexibility: Flexibility,
95        message: impl Encode<W, T::SendBuffer>,
96    ) -> Result<SendFuture<'_, T>, EncodeError>
97    where
98        W: Wire<Constraint = ()>,
99    {
100        self.inner.connection.send_message(|buffer| {
101            buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
102            buffer.encode_next(message)
103        })
104    }
105}
106
107impl<T: Transport> Clone for Client<T> {
108    fn clone(&self) -> Self {
109        Self { inner: self.inner.clone() }
110    }
111}
112
113/// A future for a pending response to a two-way message.
114pub struct TwoWayResponseFuture<'a, T: Transport> {
115    inner: &'a ClientInner<T>,
116    index: Option<u32>,
117}
118
119impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
120    fn drop(&mut self) {
121        // If `index` is `Some`, then we still need to free our locker.
122        if let Some(index) = self.index {
123            let mut responses = self.inner.responses.lock().unwrap();
124            if responses.get(index).unwrap().cancel() {
125                responses.free(index);
126            }
127        }
128    }
129}
130
131impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
132    type Output = Result<Body<T>, ProtocolError<T::Error>>;
133
134    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135        let this = Pin::into_inner(self);
136        let Some(index) = this.index else {
137            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
138        };
139
140        let mut responses = this.inner.responses.lock().unwrap();
141        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
142            Ok(ready)
143        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
144            Err(termination_reason)
145        } else {
146            return Poll::Pending;
147        };
148
149        responses.free(index);
150        this.index = None;
151        Poll::Ready(ready)
152    }
153}
154
155/// A future for a sending a two-way FIDL message.
156#[pin_project(PinnedDrop)]
157pub struct TwoWayRequestFuture<'a, T: Transport> {
158    inner: &'a ClientInner<T>,
159    index: Option<u32>,
160    #[pin]
161    send_future: SendFuture<'a, T>,
162}
163
164#[pinned_drop]
165impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
166    fn drop(self: Pin<&mut Self>) {
167        if let Some(index) = self.index {
168            let mut responses = self.inner.responses.lock().unwrap();
169
170            // The future was canceled before it could be sent. The transaction
171            // ID was never used, so it's safe to immediately reuse.
172            responses.free(index);
173        }
174    }
175}
176
177impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
178    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
179
180    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181        let this = self.project();
182
183        let Some(index) = *this.index else {
184            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
185        };
186
187        let result = ready!(this.send_future.poll(cx));
188        *this.index = None;
189        if let Err(error) = result {
190            // The send failed. Free the locker and return an error.
191            this.inner.responses.lock().unwrap().free(index);
192            Poll::Ready(Err(error))
193        } else {
194            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
195        }
196    }
197}
198
199/// A type which handles incoming events for a local client.
200///
201/// This is a variant of [`ClientHandler`] that does not require implementing
202/// `Send` and only supports local-thread executors.
203pub trait LocalClientHandler<T: Transport> {
204    /// Handles a received client event.
205    ///
206    /// See [`ClientHandler::on_event`] for more information.
207    fn on_event(
208        &mut self,
209        ordinal: u64,
210        flexibility: Flexibility,
211        body: Body<T>,
212    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>>;
213}
214
215/// A type which handles incoming events for a client.
216pub trait ClientHandler<T: Transport>: Send {
217    /// Handles a received client event.
218    ///
219    /// The client cannot handle more messages until `on_event` completes. If
220    /// `on_event` may block, or would perform asynchronous work that takes a
221    /// long time, it should offload work to an async task and return.
222    fn on_event(
223        &mut self,
224        ordinal: u64,
225        flexibility: Flexibility,
226        body: Body<T>,
227    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
228}
229
230/// An adapter for a [`ClientHandler`] which implements [`LocalClientHandler`].
231#[repr(transparent)]
232pub struct ClientHandlerToLocalAdapter<H>(H);
233
234impl<T, H> LocalClientHandler<T> for ClientHandlerToLocalAdapter<H>
235where
236    T: Transport,
237    H: ClientHandler<T>,
238{
239    #[inline]
240    fn on_event(
241        &mut self,
242        ordinal: u64,
243        flexibility: Flexibility,
244        body: Body<T>,
245    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> {
246        self.0.on_event(ordinal, flexibility, body)
247    }
248}
249
250/// A dispatcher for a client endpoint.
251///
252/// A client dispatcher receives all of the incoming messages and dispatches them to the client
253/// handler and two-way futures. It acts as the message pump for the client.
254///
255/// The dispatcher must be actively polled to receive events and two-way message responses. If the
256/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
257/// and two-way message futures will not receive their responses.
258pub struct ClientDispatcher<T: Transport> {
259    inner: Arc<ClientInner<T>>,
260    exclusive: T::Exclusive,
261    is_terminated: bool,
262}
263
264impl<T: Transport> Drop for ClientDispatcher<T> {
265    fn drop(&mut self) {
266        if !self.is_terminated {
267            // SAFETY: We checked that the connection has not been terminated.
268            unsafe {
269                self.terminate(ProtocolError::Stopped);
270            }
271        }
272    }
273}
274
275impl<T: Transport> ClientDispatcher<T> {
276    /// Creates a new client from a transport.
277    pub fn new(transport: T) -> Self {
278        let (shared, exclusive) = transport.split();
279        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
280    }
281
282    /// # Safety
283    ///
284    /// The connection must not yet be terminated.
285    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
286        // SAFETY: We checked that the connection has not been terminated.
287        unsafe {
288            self.inner.connection.terminate(error);
289        }
290        self.inner.responses.lock().unwrap().wake_all();
291    }
292
293    /// Returns a client for the dispatcher.
294    ///
295    /// When the last `Client` is dropped, the dispatcher will be stopped.
296    pub fn client(&self) -> Client<T> {
297        Client { inner: self.inner.clone() }
298    }
299
300    /// Runs the client with the provided handler.
301    pub async fn run<H>(self, handler: H) -> Result<H, ProtocolError<T::Error>>
302    where
303        H: ClientHandler<T>,
304    {
305        // The bounds on `H` prove that the future returned by `run_local` is
306        // `Send`.
307        self.run_local(ClientHandlerToLocalAdapter(handler)).await.map(|adapter| adapter.0)
308    }
309
310    /// Runs the client with the provided handler.
311    pub async fn run_local<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
312    where
313        H: LocalClientHandler<T>,
314    {
315        // We may assume that the connection has not been terminated because
316        // connections are only terminated by `run` and `drop`. Neither of those
317        // could have been called before this method because `run` consumes
318        // `self` and `drop` is only ever called once.
319
320        let error = loop {
321            // SAFETY: The connection has not been terminated.
322            let result = unsafe { self.run_one(&mut handler).await };
323            if let Err(error) = result {
324                break error;
325            }
326        };
327
328        // SAFETY: The connection has not been terminated.
329        unsafe {
330            self.terminate(error.clone());
331        }
332        self.is_terminated = true;
333
334        match error {
335            // We consider clients to have finished successfully only if they
336            // stop themselves manually.
337            ProtocolError::Stopped => Ok(handler),
338
339            // Otherwise, the client finished with an error.
340            _ => Err(error),
341        }
342    }
343
344    /// # Safety
345    ///
346    /// The connection must not be terminated.
347    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
348    where
349        H: LocalClientHandler<T>,
350    {
351        // SAFETY: The caller guaranteed that the connection is not terminated.
352        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
353
354        // This expression is really awkward due to a limitation in rustc's
355        // liveness analysis for local variables. We need to avoid holding
356        // `decoder` across `.await`s because it may not be `Send` and tasks may
357        // migrate threads between polls. We should be able to just
358        // `drop(decoder)` before any `.await`, but rustc is overly conservative
359        // and still considers `decoder` as live at the `.await` for that
360        // analysis. The only way to convince rustc that `decoder` is not live
361        // at that await point is to keep the lexical scope containing `decoder`
362        // free of `.await`s.
363        //
364        // See https://github.com/rust-lang/rust/issues/63768 for more details.
365        let header = {
366            let mut decoder = buffer.as_decoder();
367
368            let header = decoder
369                .decode_prefix::<MessageHeader>()
370                .map_err(ProtocolError::InvalidMessageHeader)?;
371
372            // Check if the ordinal is the epitaph so we can immediately decode
373            // and return it. We do this before dropping `decoder` so that we
374            // don't have to re-acquire it and wrap it in `Body`.
375            if header.ordinal == EPITAPH_ORDINAL {
376                let epitaph =
377                    decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
378                return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
379            }
380
381            header
382        };
383
384        if header.txid == 0 {
385            handler.on_event(*header.ordinal, header.flexibility(), Body::new(buffer)).await?;
386        } else {
387            let mut responses = self.inner.responses.lock().unwrap();
388            let locker = responses
389                .get(*header.txid - 1)
390                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid: *header.txid })?;
391
392            match locker.write(*header.ordinal, Body::new(buffer)) {
393                // Reader didn't cancel
394                Ok(false) => (),
395                // Reader canceled, we can drop the entry
396                Ok(true) => responses.free(*header.txid - 1),
397                Err(LockerError::NotWriteable) => {
398                    return Err(ProtocolError::UnrequestedResponse { txid: *header.txid });
399                }
400                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
401                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
402                }
403            }
404        }
405
406        Ok(())
407    }
408}