fidl_next_protocol/endpoints/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_next_codec::{Encode, EncodeError, EncoderExt};
12
13use crate::concurrency::sync::{Arc, Mutex};
14
15use crate::{ProtocolError, Transport, decode_epitaph, decode_header, encode_header};
16
17use super::connection::{Connection, ORDINAL_EPITAPH, SendFuture};
18use super::lockers::{LockerError, Lockers};
19
20struct ClientSenderInner<T: Transport> {
21    connection: Connection<T>,
22    responses: Mutex<Lockers<T::RecvBuffer>>,
23}
24
25impl<T: Transport> ClientSenderInner<T> {
26    fn new(shared: T::Shared) -> Self {
27        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
28    }
29}
30
31/// A sender for a client endpoint.
32pub struct ClientSender<T: Transport> {
33    inner: Arc<ClientSenderInner<T>>,
34}
35
36impl<T: Transport> Drop for ClientSender<T> {
37    fn drop(&mut self) {
38        if Arc::strong_count(&self.inner) == 2 {
39            // This was the last client sender other than the one in the client
40            // itself. Stop the connection.
41            self.close();
42        }
43    }
44}
45
46impl<T: Transport> ClientSender<T> {
47    /// Closes the channel from the client end.
48    pub fn close(&self) {
49        self.inner.connection.stop();
50    }
51
52    /// Send a request.
53    pub fn send_one_way<M>(
54        &self,
55        ordinal: u64,
56        request: M,
57    ) -> Result<SendFuture<'_, T>, EncodeError>
58    where
59        M: Encode<T::SendBuffer>,
60    {
61        self.send_message(0, ordinal, request)
62    }
63
64    /// Send a request and await for a response.
65    pub fn send_two_way<M>(
66        &self,
67        ordinal: u64,
68        request: M,
69    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
70    where
71        M: Encode<T::SendBuffer>,
72    {
73        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
74
75        // Send with txid = index + 1 because indices start at 0.
76        match self.send_message(index + 1, ordinal, request) {
77            Ok(send_future) => {
78                Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
79            }
80            Err(e) => {
81                self.inner.responses.lock().unwrap().free(index);
82                Err(e)
83            }
84        }
85    }
86
87    fn send_message<M>(
88        &self,
89        txid: u32,
90        ordinal: u64,
91        message: M,
92    ) -> Result<SendFuture<'_, T>, EncodeError>
93    where
94        M: Encode<T::SendBuffer>,
95    {
96        self.inner.connection.send_with(|buffer| {
97            encode_header::<T>(buffer, txid, ordinal)?;
98            buffer.encode_next(message)
99        })
100    }
101}
102
103impl<T: Transport> Clone for ClientSender<T> {
104    fn clone(&self) -> Self {
105        Self { inner: self.inner.clone() }
106    }
107}
108
109/// A future for a pending response to a two-way message.
110pub struct TwoWayResponseFuture<'a, T: Transport> {
111    inner: &'a ClientSenderInner<T>,
112    index: Option<u32>,
113}
114
115impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
116    fn drop(&mut self) {
117        // If `index` is `Some`, then we still need to free our locker.
118        if let Some(index) = self.index {
119            let mut responses = self.inner.responses.lock().unwrap();
120            if responses.get(index).unwrap().cancel() {
121                responses.free(index);
122            }
123        }
124    }
125}
126
127impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
128    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
129
130    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
131        let this = Pin::into_inner(self);
132        let Some(index) = this.index else {
133            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
134        };
135
136        let mut responses = this.inner.responses.lock().unwrap();
137        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
138            Ok(ready)
139        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
140            Err(termination_reason)
141        } else {
142            return Poll::Pending;
143        };
144
145        responses.free(index);
146        this.index = None;
147        Poll::Ready(ready)
148    }
149}
150
151/// A future for a sending a two-way FIDL message.
152pub struct TwoWayRequestFuture<'a, T: Transport> {
153    inner: &'a ClientSenderInner<T>,
154    index: Option<u32>,
155    send_future: SendFuture<'a, T>,
156}
157
158impl<T: Transport> Drop for TwoWayRequestFuture<'_, T> {
159    fn drop(&mut self) {
160        if let Some(index) = self.index {
161            let mut responses = self.inner.responses.lock().unwrap();
162
163            // The future was canceled before it could be sent. The transaction
164            // ID was never used, so it's safe to immediately reuse.
165            responses.free(index);
166        }
167    }
168}
169
170impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
171    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
172
173    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174        // SAFETY: We treat the state as pinned as long as it is sending.
175        let this = unsafe { Pin::into_inner_unchecked(self) };
176
177        let Some(index) = this.index else {
178            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
179        };
180
181        // SAFETY: `send_future` is structurally pinned.
182        let send_future = unsafe { Pin::new_unchecked(&mut this.send_future) };
183
184        let result = ready!(send_future.poll(cx));
185        this.index = None;
186        if let Err(error) = result {
187            // The send failed. Free the locker and return an error.
188            this.inner.responses.lock().unwrap().free(index);
189            Poll::Ready(Err(error))
190        } else {
191            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
192        }
193    }
194}
195
196/// A type which handles incoming events for a client.
197pub trait ClientHandler<T: Transport> {
198    /// Handles a received client event.
199    ///
200    /// The client cannot handle more messages until `on_event` completes. If `on_event` should
201    /// handle requests in parallel, it should spawn a new async task and return.
202    fn on_event(
203        &mut self,
204        sender: &ClientSender<T>,
205        ordinal: u64,
206        buffer: T::RecvBuffer,
207    ) -> impl Future<Output = ()> + Send;
208}
209
210/// A client for an endpoint.
211///
212/// It must be actively polled to receive events and two-way message responses.
213pub struct Client<T: Transport> {
214    sender: ClientSender<T>,
215    exclusive: T::Exclusive,
216    is_terminated: bool,
217}
218
219impl<T: Transport> Drop for Client<T> {
220    fn drop(&mut self) {
221        if !self.is_terminated {
222            // SAFETY: We checked that the connection has not been terminated.
223            unsafe {
224                self.terminate(ProtocolError::Stopped);
225            }
226        }
227    }
228}
229
230impl<T: Transport> Client<T> {
231    /// Creates a new client from a transport.
232    pub fn new(transport: T) -> Self {
233        let (shared, exclusive) = transport.split();
234        let inner = Arc::new(ClientSenderInner::new(shared));
235        Self { sender: ClientSender { inner }, exclusive, is_terminated: false }
236    }
237
238    /// # Safety
239    ///
240    /// The connection must not yet be terminated.
241    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
242        // SAFETY: We checked that the connection has not been terminated.
243        unsafe {
244            self.sender.inner.connection.terminate(error);
245        }
246        self.sender.inner.responses.lock().unwrap().wake_all();
247    }
248
249    /// Returns the sender for the client.
250    pub fn sender(&self) -> &ClientSender<T> {
251        &self.sender
252    }
253
254    /// Runs the client with the provided handler.
255    pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
256    where
257        H: ClientHandler<T>,
258    {
259        // We may assume that the connection has not been terminated because
260        // connections are only terminated by `run` and `drop`. Neither of those
261        // could have been called before this method because `run` consumes
262        // `self` and `drop` is only ever called once.
263
264        let error = loop {
265            // SAFETY: The connection has not been terminated.
266            let result = unsafe { self.run_one(&mut handler).await };
267            if let Err(error) = result {
268                break error;
269            }
270        };
271
272        // SAFETY: The connection has not been terminated.
273        unsafe {
274            self.terminate(error.clone());
275        }
276        self.is_terminated = true;
277
278        match error {
279            // We consider clients to have finished successfully only if they
280            // stop themselves manually.
281            ProtocolError::Stopped => Ok(handler),
282
283            // Otherwise, the client finished with an error.
284            _ => Err(error),
285        }
286    }
287
288    /// # Safety
289    ///
290    /// The connection must not be terminated.
291    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
292    where
293        H: ClientHandler<T>,
294    {
295        // SAFETY: The caller guaranteed that the connection is not terminated.
296        let mut buffer = unsafe { self.sender.inner.connection.recv(&mut self.exclusive).await? };
297
298        let (txid, ordinal) =
299            decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
300
301        if ordinal == ORDINAL_EPITAPH {
302            let epitaph =
303                decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
304            return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
305        } else if txid == 0 {
306            handler.on_event(&self.sender, ordinal, buffer).await;
307        } else {
308            let mut responses = self.sender.inner.responses.lock().unwrap();
309            let locker = responses
310                .get(txid - 1)
311                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
312
313            match locker.write(ordinal, buffer) {
314                // Reader didn't cancel
315                Ok(false) => (),
316                // Reader canceled, we can drop the entry
317                Ok(true) => responses.free(txid - 1),
318                Err(LockerError::NotWriteable) => {
319                    return Err(ProtocolError::UnrequestedResponse { txid });
320                }
321                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
322                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
323                }
324            }
325        }
326
327        Ok(())
328    }
329
330    /// Runs the client with the [`IgnoreEvents`] handler.
331    pub async fn run_sender(self) -> Result<(), ProtocolError<T::Error>> {
332        self.run(IgnoreEvents).await.map(|_| ())
333    }
334}
335
336/// A client handler which ignores any incoming events.
337pub struct IgnoreEvents;
338
339impl<T: Transport> ClientHandler<T> for IgnoreEvents {
340    async fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
341}