// fidl_next_protocol/endpoints/client.rs
1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_constants::EPITAPH_ORDINAL;
12use fidl_next_codec::{AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire};
13use pin_project::{pin_project, pinned_drop};
14
15use crate::concurrency::sync::{Arc, Mutex};
16use crate::endpoints::connection::Connection;
17use crate::endpoints::lockers::{LockerError, Lockers};
18use crate::wire::{Epitaph, MessageHeader};
19use crate::{Body, Flexibility, ProtocolError, SendFuture, Transport};
20
/// State shared between every `Client` handle and the `ClientDispatcher`.
struct ClientInner<T: Transport> {
    /// The shared half of the underlying transport, used to send messages.
    connection: Connection<T>,
    /// Lockers holding pending two-way response bodies, keyed by locker
    /// index (which is txid - 1; see `Client::send_two_way`).
    responses: Mutex<Lockers<Body<T>>>,
}
25
26impl<T: Transport> ClientInner<T> {
27    fn new(shared: T::Shared) -> Self {
28        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
29    }
30}
31
/// A client endpoint.
///
/// Cloning a `Client` produces another handle to the same shared connection
/// state; the underlying connection is stopped when the last `Client` is
/// dropped (see the `Drop` impl).
pub struct Client<T: Transport> {
    // Shared with the `ClientDispatcher` and all other `Client` clones.
    inner: Arc<ClientInner<T>>,
}
36
37impl<T: Transport> Drop for Client<T> {
38    fn drop(&mut self) {
39        if Arc::strong_count(&self.inner) == 2 {
40            // This was the last reference to the connection other than the one
41            // in the dispatcher itself. Stop the connection.
42            self.close();
43        }
44    }
45}
46
47impl<T: Transport> Client<T> {
48    /// Closes the channel from the client end.
49    pub fn close(&self) {
50        self.inner.connection.stop();
51    }
52
53    /// Send a request.
54    pub fn send_one_way<W>(
55        &self,
56        ordinal: u64,
57        flexibility: Flexibility,
58        request: impl Encode<W, T::SendBuffer>,
59    ) -> Result<SendFuture<'_, T>, EncodeError>
60    where
61        W: Wire<Constraint = ()>,
62    {
63        self.send_message(0, ordinal, flexibility, request)
64    }
65
66    /// Send a request and await for a response.
67    pub fn send_two_way<W>(
68        &self,
69        ordinal: u64,
70        flexibility: Flexibility,
71        request: impl Encode<W, T::SendBuffer>,
72    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
73    where
74        W: Wire<Constraint = ()>,
75    {
76        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
77
78        // Send with txid = index + 1 because indices start at 0.
79        match self.send_message(index + 1, ordinal, flexibility, request) {
80            Ok(send_future) => {
81                Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
82            }
83            Err(e) => {
84                self.inner.responses.lock().unwrap().free(index);
85                Err(e)
86            }
87        }
88    }
89
90    fn send_message<W>(
91        &self,
92        txid: u32,
93        ordinal: u64,
94        flexibility: Flexibility,
95        message: impl Encode<W, T::SendBuffer>,
96    ) -> Result<SendFuture<'_, T>, EncodeError>
97    where
98        W: Wire<Constraint = ()>,
99    {
100        self.inner.connection.send_message(|buffer| {
101            buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
102            buffer.encode_next(message)
103        })
104    }
105}
106
107impl<T: Transport> Clone for Client<T> {
108    fn clone(&self) -> Self {
109        Self { inner: self.inner.clone() }
110    }
111}
112
/// A future for a pending response to a two-way message.
pub struct TwoWayResponseFuture<'a, T: Transport> {
    /// Shared client state containing the response lockers.
    inner: &'a ClientInner<T>,
    /// The locker index for this response; `None` once the future has
    /// resolved and its locker has been freed.
    index: Option<u32>,
}
118
impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
    fn drop(&mut self) {
        // If `index` is `Some`, then we still need to free our locker.
        if let Some(index) = self.index {
            let mut responses = self.inner.responses.lock().unwrap();
            // `cancel` appears to report whether the locker can be freed
            // immediately; when it can't, the dispatcher frees it once the
            // response arrives (the `Ok(true)` arm of `locker.write` in
            // `run_one`) — confirm against `Lockers` docs.
            if responses.get(index).unwrap().cancel() {
                responses.free(index);
            }
        }
    }
}
130
impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
    type Output = Result<Body<T>, ProtocolError<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // No pinned fields, so it's safe to unpin `self`.
        let this = Pin::into_inner(self);
        let Some(index) = this.index else {
            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
        };

        // Prefer a delivered response body; otherwise surface the
        // connection's termination reason, if any. `read` is passed the waker
        // so the dispatcher can wake this future when the response arrives.
        let mut responses = this.inner.responses.lock().unwrap();
        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
            Ok(ready)
        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
            Err(termination_reason)
        } else {
            return Poll::Pending;
        };

        // The locker is done; free it and clear `index` so that `Drop` (and
        // any buggy re-poll) doesn't touch it again.
        responses.free(index);
        this.index = None;
        Poll::Ready(ready)
    }
}
154
/// A future for sending a two-way FIDL message.
#[pin_project(PinnedDrop)]
pub struct TwoWayRequestFuture<'a, T: Transport> {
    /// Shared client state containing the response lockers.
    inner: &'a ClientInner<T>,
    /// The locker index reserved for the response; `None` once this future
    /// has completed and ownership of the locker has moved on.
    index: Option<u32>,
    /// The in-progress send; must be structurally pinned to be polled.
    #[pin]
    send_future: SendFuture<'a, T>,
}
163
#[pinned_drop]
impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
    // Frees the reserved locker if the request was dropped before the send
    // completed.
    fn drop(self: Pin<&mut Self>) {
        if let Some(index) = self.index {
            let mut responses = self.inner.responses.lock().unwrap();

            // The future was canceled before it could be sent. The transaction
            // ID was never used, so it's safe to immediately reuse.
            responses.free(index);
        }
    }
}
176
impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
    // On success, resolves to a future for the response.
    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        let Some(index) = *this.index else {
            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
        };

        // Drive the send to completion first.
        let result = ready!(this.send_future.poll(cx));
        // Locker ownership leaves this future here: it moves to the response
        // future on success, or back to the pool on failure — so `PinnedDrop`
        // must no longer free it.
        *this.index = None;
        if let Err(error) = result {
            // The send failed. Free the locker and return an error.
            this.inner.responses.lock().unwrap().free(index);
            Poll::Ready(Err(error))
        } else {
            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
        }
    }
}
198
/// A type which handles incoming events for a client.
pub trait ClientHandler<T: Transport> {
    /// Handles a received client event, returning the appropriate flow control
    /// to perform.
    ///
    /// `ordinal` identifies the event, `flexibility` is taken from the
    /// message header, and `body` is the not-yet-decoded event body.
    ///
    /// The client cannot handle more messages until `on_event` completes. If
    /// `on_event` should handle requests in parallel, it should spawn a new
    /// async task and return.
    fn on_event(
        &mut self,
        ordinal: u64,
        flexibility: Flexibility,
        body: Body<T>,
    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
}
214
/// A dispatcher for a client endpoint.
///
/// A client dispatcher receives all of the incoming messages and dispatches them to the client
/// handler and two-way futures. It acts as the message pump for the client.
///
/// The dispatcher must be actively polled to receive events and two-way message responses. If the
/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
/// and two-way message futures will not receive their responses.
pub struct ClientDispatcher<T: Transport> {
    /// State shared with every `Client` handle.
    inner: Arc<ClientInner<T>>,
    /// The exclusive half of the transport, used to receive messages in
    /// `run`.
    exclusive: T::Exclusive,
    /// Whether the connection has been terminated; checked in `Drop` so the
    /// connection is never terminated twice.
    is_terminated: bool,
}
228
impl<T: Transport> Drop for ClientDispatcher<T> {
    fn drop(&mut self) {
        // `run` sets `is_terminated` after terminating; if the dispatcher is
        // dropped without `run` having finished (e.g. it was never run, or
        // the `run` future was canceled), terminate here so pending response
        // futures are woken.
        if !self.is_terminated {
            // SAFETY: We checked that the connection has not been terminated.
            unsafe {
                self.terminate(ProtocolError::Stopped);
            }
        }
    }
}
239
impl<T: Transport> ClientDispatcher<T> {
    /// Creates a new client dispatcher from a transport.
    pub fn new(transport: T) -> Self {
        // The shared half goes into `ClientInner` for `Client`s to send
        // through; the exclusive half stays here and is used to receive.
        let (shared, exclusive) = transport.split();
        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
    }

    /// Terminates the connection with `error` and wakes every pending
    /// two-way response future.
    ///
    /// # Safety
    ///
    /// The connection must not yet be terminated.
    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
        // SAFETY: We checked that the connection has not been terminated.
        unsafe {
            self.inner.connection.terminate(error);
        }
        // Wake all pending response futures so they observe the termination
        // reason instead of pending forever.
        self.inner.responses.lock().unwrap().wake_all();
    }

    /// Returns a client for the dispatcher.
    ///
    /// When the last `Client` is dropped, the dispatcher will be stopped.
    pub fn client(&self) -> Client<T> {
        Client { inner: self.inner.clone() }
    }

    /// Runs the client with the provided handler.
    ///
    /// Returns the handler if the client stopped itself deliberately, and
    /// the terminating error otherwise.
    pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        // We may assume that the connection has not been terminated because
        // connections are only terminated by `run` and `drop`. Neither of those
        // could have been called before this method because `run` consumes
        // `self` and `drop` is only ever called once.

        // Receive and dispatch messages until an error (including `Stopped`)
        // breaks the loop.
        let error = loop {
            // SAFETY: The connection has not been terminated.
            let result = unsafe { self.run_one(&mut handler).await };
            if let Err(error) = result {
                break error;
            }
        };

        // SAFETY: The connection has not been terminated.
        unsafe {
            self.terminate(error.clone());
        }
        // Record the termination so `Drop` doesn't terminate a second time.
        self.is_terminated = true;

        match error {
            // We consider clients to have finished successfully only if they
            // stop themselves manually.
            ProtocolError::Stopped => Ok(handler),

            // Otherwise, the client finished with an error.
            _ => Err(error),
        }
    }

    /// Receives and dispatches a single incoming message.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated.
    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        // SAFETY: The caller guaranteed that the connection is not terminated.
        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };

        // This expression is really awkward due to a limitation in rustc's
        // liveness analysis for local variables. We need to avoid holding
        // `decoder` across `.await`s because it may not be `Send` and tasks may
        // migrate threads between polls. We should be able to just
        // `drop(decoder)` before any `.await`, but rustc is overly conservative
        // and still considers `decoder` as live at the `.await` for that
        // analysis. The only way to convince rustc that `decoder` is not live
        // at that await point is to keep the lexical scope containing `decoder`
        // free of `.await`s.
        //
        // See https://github.com/rust-lang/rust/issues/63768 for more details.
        let header = {
            let mut decoder = buffer.as_decoder();

            let header = decoder
                .decode_prefix::<MessageHeader>()
                .map_err(ProtocolError::InvalidMessageHeader)?;

            // Check if the ordinal is the epitaph so we can immediately decode
            // and return it. We do this before dropping `decoder` so that we
            // don't have to re-acquire it and wrap it in `Body`.
            if header.ordinal == EPITAPH_ORDINAL {
                let epitaph =
                    decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
                return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
            }

            header
        };

        if header.txid == 0 {
            // A zero txid marks an event; hand it to the handler.
            handler.on_event(*header.ordinal, header.flexibility(), Body::new(buffer)).await?;
        } else {
            // Otherwise this is a response; route it to the locker reserved
            // by `send_two_way`, which sent txid = locker index + 1.
            let mut responses = self.inner.responses.lock().unwrap();
            let locker = responses
                .get(*header.txid - 1)
                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid: *header.txid })?;

            match locker.write(*header.ordinal, Body::new(buffer)) {
                // Reader didn't cancel
                Ok(false) => (),
                // Reader canceled, we can drop the entry
                Ok(true) => responses.free(*header.txid - 1),
                Err(LockerError::NotWriteable) => {
                    return Err(ProtocolError::UnrequestedResponse { txid: *header.txid });
                }
                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
                }
            }
        }

        Ok(())
    }
}
363}