Skip to main content

fdomain_client/
lib.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fidl_fuchsia_fdomain as proto;
6use fidl_message::TransactionHeader;
7use fuchsia_sync::Mutex;
8use futures::FutureExt;
9use futures::channel::oneshot::Sender as OneshotSender;
10use futures::stream::Stream as StreamTrait;
11use std::collections::{HashMap, VecDeque};
12use std::convert::Infallible;
13use std::future::Future;
14use std::num::NonZeroU32;
15use std::pin::Pin;
16use std::sync::{Arc, LazyLock, Weak};
17use std::task::{Context, Poll, Waker, ready};
18
19mod channel;
20mod event;
21mod event_pair;
22mod handle;
23mod responder;
24mod socket;
25
26#[cfg(test)]
27mod test;
28
29pub mod fidl;
30pub mod fidl_next;
31
32use responder::Responder;
33
34pub use channel::{
35    AnyHandle, Channel, ChannelMessageStream, ChannelWriter, HandleInfo, HandleOp, MessageBuf,
36};
37pub use event::Event;
38pub use event_pair::Eventpair as EventPair;
39pub use handle::unowned::Unowned;
40pub use handle::{
41    AsHandleRef, Handle, HandleBased, HandleRef, NullableHandle, OnFDomainSignals, Peered,
42};
43pub use proto::{Error as FDomainError, WriteChannelError, WriteSocketError};
44pub use socket::{Socket, SocketDisposition, SocketReadStream, SocketWriter};
45
46// Unsupported handle types.
47#[rustfmt::skip]
48pub use Handle as Clock;
49#[rustfmt::skip]
50pub use Handle as Exception;
51#[rustfmt::skip]
52pub use Handle as Fifo;
53#[rustfmt::skip]
54pub use Handle as Iob;
55#[rustfmt::skip]
56pub use Handle as Job;
57#[rustfmt::skip]
58pub use Handle as Process;
59#[rustfmt::skip]
60pub use Handle as Resource;
61#[rustfmt::skip]
62pub use Handle as Stream;
63#[rustfmt::skip]
64pub use Handle as Thread;
65#[rustfmt::skip]
66pub use Handle as Vmar;
67#[rustfmt::skip]
68pub use Handle as Vmo;
69#[rustfmt::skip]
70pub use Handle as Counter;
71
72use proto::f_domain_ordinals as ordinals;
73
74fn write_fdomain_error(error: &FDomainError, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75    match error {
76        FDomainError::TargetError(e) => {
77            let e = zx_status::Status::from_raw(*e);
78            write!(f, "Target-side error {e}")
79        }
80        FDomainError::BadHandleId(proto::BadHandleId { id }) => {
81            write!(f, "Tried to use invalid handle id {id}")
82        }
83        FDomainError::WrongHandleType(proto::WrongHandleType { expected, got }) => write!(
84            f,
85            "Tried to use handle as {expected:?} but target reported handle was of type {got:?}"
86        ),
87        FDomainError::StreamingReadInProgress(proto::StreamingReadInProgress {}) => {
88            write!(f, "Handle is occupied delivering streaming reads")
89        }
90        FDomainError::NoReadInProgress(proto::NoReadInProgress {}) => {
91            write!(f, "No streaming read was in progress")
92        }
93        FDomainError::NewHandleIdOutOfRange(proto::NewHandleIdOutOfRange { id }) => {
94            write!(
95                f,
96                "Tried to create a handle with id {id}, which is outside the valid range for client handles"
97            )
98        }
99        FDomainError::NewHandleIdReused(proto::NewHandleIdReused { id, same_call }) => {
100            if *same_call {
101                write!(f, "Tried to create two or more new handles with the same id {id}")
102            } else {
103                write!(
104                    f,
105                    "Tried to create a new handle with id {id}, which is already the id of an existing handle"
106                )
107            }
108        }
109        FDomainError::WroteToSelf(proto::WroteToSelf {}) => {
110            write!(f, "Tried to write a channel into itself")
111        }
112        FDomainError::ClosedDuringRead(proto::ClosedDuringRead {}) => {
113            write!(f, "Handle closed while being read")
114        }
115        FDomainError::SignalsUnknown(signals_unknown) => {
116            write!(f, "Unknown signals: {:x}", signals_unknown.signals)
117        }
118        FDomainError::RightsUnknown(rights_unknown) => {
119            write!(f, "Unknown rights: {:x}", rights_unknown.rights)
120        }
121        FDomainError::SocketDispositionUnknown(socket_disposition_unknown) => {
122            write!(f, "Unknown socket disposition: {:?}", socket_disposition_unknown.disposition)
123        }
124        FDomainError::SocketTypeUnknown(socket_type_unknown) => {
125            write!(f, "Unknown socket type: {:?}", socket_type_unknown.type_)
126        }
127        e => write!(f, "Unknown FDomain error: {e:?}"),
128    }
129}
130
131/// Result type alias.
132pub type Result<T, E = Error> = std::result::Result<T, E>;
133
/// Error type emitted by FDomain operations.
#[derive(Clone)]
pub enum Error {
    /// A socket write failed; the payload also records how many bytes were
    /// written successfully before the failure.
    SocketWrite(WriteSocketError),
    /// A channel write failed, either as a whole or for individual handles
    /// being transferred.
    ChannelWrite(WriteChannelError),
    /// An error expressed in the FDomain protocol's own error type.
    FDomain(FDomainError),
    /// A FIDL-level protocol error, e.g. a message that failed to decode, an
    /// unknown event ordinal, or a reply with an unexpected transaction ID.
    Protocol(::fidl::Error),
    /// The protocol carried an unrecognized or incompatible object type.
    ProtocolObjectTypeIncompatible,
    /// The protocol carried unrecognized or incompatible handle rights.
    ProtocolRightsIncompatible,
    /// The protocol carried unrecognized or incompatible signals.
    ProtocolSignalsIncompatible,
    /// The protocol carried an unrecognized or incompatible streaming IO event.
    ProtocolStreamEventIncompatible,
    /// The transport failed. `Some` carries the underlying I/O error; `None`
    /// means the connection was lost without one.
    Transport(Option<Arc<std::io::Error>>),
    /// A handle was used with a different connection than the one it was
    /// created on.
    ConnectionMismatch,
    /// Streaming on a channel has been aborted.
    StreamingAborted,
}
149
150impl std::fmt::Display for Error {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        match self {
153            Self::SocketWrite(proto::WriteSocketError { error, wrote }) => {
154                write!(f, "While writing socket (after {wrote} bytes written successfully): ")?;
155                write_fdomain_error(error, f)
156            }
157            Self::ChannelWrite(proto::WriteChannelError::Error(error)) => {
158                write!(f, "While writing channel: ")?;
159                write_fdomain_error(error, f)
160            }
161            Self::ChannelWrite(proto::WriteChannelError::OpErrors(errors)) => {
162                write!(f, "Couldn't write all handles into a channel:")?;
163                for (pos, error) in
164                    errors.iter().enumerate().filter_map(|(num, x)| x.as_ref().map(|y| (num, &**y)))
165                {
166                    write!(f, "\n  Handle in position {pos}: ")?;
167                    write_fdomain_error(error, f)?;
168                }
169                Ok(())
170            }
171            Self::ProtocolObjectTypeIncompatible => {
172                write!(
173                    f,
174                    "The FDomain protocol received an unrecognized or incompatible object type"
175                )
176            }
177            Self::ProtocolRightsIncompatible => {
178                write!(
179                    f,
180                    "The FDomain protocol received unrecognized or incompatible handle rights"
181                )
182            }
183            Self::ProtocolSignalsIncompatible => {
184                write!(f, "The FDomain protocol received unrecognized or incompatible signals")
185            }
186            Self::ProtocolStreamEventIncompatible => {
187                write!(
188                    f,
189                    "The FDomain protocol received an unrecognized or incompatible streaming IO event"
190                )
191            }
192            Self::FDomain(e) => write_fdomain_error(e, f),
193            Self::Protocol(e) => write!(f, "Protocol error: {e}"),
194            Self::Transport(Some(e)) => write!(f, "Transport error: {e}"),
195            Self::Transport(None) => {
196                write!(f, "Transport error: Connection to the device has been lost")
197            }
198            Self::ConnectionMismatch => {
199                write!(
200                    f,
201                    "Tried to use an FDomain handle with a different connection than the one it was created on"
202                )
203            }
204            Self::StreamingAborted => write!(f, "Streaming on this channel has been aborted"),
205        }
206    }
207}
208
209impl std::fmt::Debug for Error {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        match self {
212            Self::SocketWrite(e) => f.debug_tuple("SocketWrite").field(e).finish(),
213            Self::ChannelWrite(e) => f.debug_tuple("ChannelWrite").field(e).finish(),
214            Self::FDomain(e) => f.debug_tuple("FDomain").field(e).finish(),
215            Self::Protocol(e) => f.debug_tuple("Protocol").field(e).finish(),
216            Self::Transport(e) => f.debug_tuple("Transport").field(e).finish(),
217            Self::ProtocolObjectTypeIncompatible => write!(f, "ProtocolObjectTypeIncompatible "),
218            Self::ProtocolRightsIncompatible => write!(f, "ProtocolRightsIncompatible "),
219            Self::ProtocolSignalsIncompatible => write!(f, "ProtocolSignalsIncompatible "),
220            Self::ProtocolStreamEventIncompatible => write!(f, "ProtocolStreamEventIncompatible"),
221            Self::ConnectionMismatch => write!(f, "ConnectionMismatch"),
222            Self::StreamingAborted => write!(f, "StreamingAborted"),
223        }
224    }
225}
226
227impl std::error::Error for Error {}
228
229impl From<FDomainError> for Error {
230    fn from(other: FDomainError) -> Self {
231        Self::FDomain(other)
232    }
233}
234
235impl From<::fidl::Error> for Error {
236    fn from(other: ::fidl::Error) -> Self {
237        Self::Protocol(other)
238    }
239}
240
241impl From<WriteSocketError> for Error {
242    fn from(other: WriteSocketError) -> Self {
243        Self::SocketWrite(other)
244    }
245}
246
247impl From<WriteChannelError> for Error {
248    fn from(other: WriteChannelError) -> Self {
249        Self::ChannelWrite(other)
250    }
251}
252
253/// An error emitted internally by the client. Similar to [`Error`] but does not
254/// contain several variants which are irrelevant in the contexts where it is
255/// used.
256#[derive(Clone)]
257enum InnerError {
258    Protocol(::fidl::Error),
259    ProtocolStreamEventIncompatible,
260    Transport(Option<Arc<std::io::Error>>),
261}
262
263impl From<InnerError> for Error {
264    fn from(other: InnerError) -> Self {
265        match other {
266            InnerError::Protocol(p) => Error::Protocol(p),
267            InnerError::ProtocolStreamEventIncompatible => Error::ProtocolStreamEventIncompatible,
268            InnerError::Transport(t) => Error::Transport(t),
269        }
270    }
271}
272
273impl From<::fidl::Error> for InnerError {
274    fn from(other: ::fidl::Error) -> Self {
275        InnerError::Protocol(other)
276    }
277}
278
279// TODO(399717689) Figure out if we could just use AsyncRead/Write instead of a special trait.
280/// Implemented by objects which provide a transport over which we can speak the
281/// FDomain protocol.
282///
283/// The implementer must provide two things:
284/// 1) An incoming stream of messages presented as `Vec<u8>`. This is provided
285///    via the `Stream` trait, which this trait requires.
286/// 2) A way to send messages. This is provided by implementing the
287///    `poll_send_message` method.
288pub trait FDomainTransport: StreamTrait<Item = Result<Box<[u8]>, std::io::Error>> + Send {
289    /// Attempt to send a message asynchronously. Messages should be sent so
290    /// that they arrive at the target in order.
291    fn poll_send_message(
292        self: Pin<&mut Self>,
293        msg: &[u8],
294        ctx: &mut Context<'_>,
295    ) -> Poll<Result<(), Option<std::io::Error>>>;
296
297    /// Optional debug information outlet.
298    fn debug_fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        Ok(())
300    }
301
302    /// Whether `debug_fmt` does anything.
303    fn has_debug_fmt(&self) -> bool {
304        false
305    }
306}
307
/// Wrapper for an `FDomainTransport` implementer that:
/// 1) Provides a queue for outgoing messages so we need not have an await point
///    when we submit a message.
/// 2) Drops the transport on error, then returns the last observed error for
///    all future operations.
enum Transport {
    /// A live transport: the transport itself, the queue of encoded messages
    /// not yet sent, and the wakers to notify when a message is queued
    /// (`push_msg`) or when this value is dropped.
    Transport(Pin<Box<dyn FDomainTransport>>, VecDeque<Box<[u8]>>, Vec<Waker>),
    /// The transport has failed; a clone of the stored error is returned by
    /// every subsequent operation.
    Error(InnerError),
}
317
318impl Transport {
319    /// Get the failure mode of the transport if it has failed.
320    fn error(&self) -> Option<InnerError> {
321        match self {
322            Transport::Transport(_, _, _) => None,
323            Transport::Error(inner_error) => Some(inner_error.clone()),
324        }
325    }
326
327    /// Enqueue a message to be sent on this transport.
328    fn push_msg(&mut self, msg: Box<[u8]>) -> Result<(), InnerError> {
329        match self {
330            Transport::Transport(_, v, w) => {
331                v.push_back(msg);
332                w.drain(..).for_each(Waker::wake);
333                Ok(())
334            }
335            Transport::Error(e) => Err(e.clone()),
336        }
337    }
338
339    /// Push messages in the send queue out through the transport.
340    fn poll_send_messages(&mut self, ctx: &mut Context<'_>) -> Poll<InnerError> {
341        match self {
342            Transport::Error(e) => Poll::Ready(e.clone()),
343            Transport::Transport(t, v, w) => {
344                while let Some(msg) = v.front() {
345                    match t.as_mut().poll_send_message(msg, ctx) {
346                        Poll::Ready(Ok(())) => {
347                            v.pop_front();
348                        }
349                        Poll::Ready(Err(e)) => {
350                            let e = e.map(Arc::new);
351                            return Poll::Ready(InnerError::Transport(e));
352                        }
353                        Poll::Pending => return Poll::Pending,
354                    }
355                }
356
357                if v.is_empty() {
358                    w.push(ctx.waker().clone());
359                } else {
360                    ctx.waker().wake_by_ref();
361                }
362                Poll::Pending
363            }
364        }
365    }
366
367    /// Get the next incoming message from the transport.
368    fn poll_next(&mut self, ctx: &mut Context<'_>) -> Poll<Result<Box<[u8]>, InnerError>> {
369        match self {
370            Transport::Error(e) => Poll::Ready(Err(e.clone())),
371            Transport::Transport(t, _, _) => match ready!(t.as_mut().poll_next(ctx)) {
372                Some(Ok(x)) => Poll::Ready(Ok(x)),
373                Some(Err(e)) => Poll::Ready(Err(InnerError::Transport(Some(Arc::new(e))))),
374                Option::None => Poll::Ready(Err(InnerError::Transport(None))),
375            },
376        }
377    }
378}
379
380impl Drop for Transport {
381    fn drop(&mut self) {
382        if let Transport::Transport(_, _, wakers) = self {
383            wakers.drain(..).for_each(Waker::wake);
384        }
385    }
386}
387
388/// State of a socket that is or has been read from.
389struct SocketReadState {
390    wakers: Vec<Waker>,
391    queued: VecDeque<Result<proto::SocketData, Error>>,
392    read_request_pending: bool,
393    is_streaming: bool,
394}
395
396impl SocketReadState {
397    /// Handle an incoming message, which is either a channel streaming event or
398    /// response to a `ChannelRead` request.
399    fn handle_incoming_message(&mut self, msg: Result<proto::SocketData, Error>) -> Vec<Waker> {
400        self.queued.push_back(msg);
401        std::mem::replace(&mut self.wakers, Vec::new())
402    }
403}
404
405/// State of a channel that is or has been read from.
406struct ChannelReadState {
407    wakers: Vec<Waker>,
408    queued: VecDeque<Result<proto::ChannelMessage, Error>>,
409    read_request_pending: bool,
410    is_streaming: bool,
411}
412
413impl ChannelReadState {
414    /// Handle an incoming message, which is either a channel streaming event or
415    /// response to a `ChannelRead` request.
416    fn handle_incoming_message(&mut self, msg: Result<proto::ChannelMessage, Error>) -> Vec<Waker> {
417        self.queued.push_back(msg);
418        std::mem::replace(&mut self.wakers, Vec::new())
419    }
420}
421
/// Lock-protected interior of `Client`
struct ClientInner {
    /// The connection to the target, or the error it failed with.
    transport: Transport,
    /// Outstanding requests keyed by transaction ID. Each responder is invoked
    /// with its reply, or with an error if the transport fails.
    transactions: HashMap<NonZeroU32, responder::Responder>,
    /// Buffered read results and parked wakers for channels, by handle.
    channel_read_states: HashMap<proto::HandleId, ChannelReadState>,
    /// Buffered read results and parked wakers for sockets, by handle.
    socket_read_states: HashMap<proto::HandleId, SocketReadState>,
    /// Transaction ID to use for the next request. Initialized to 1 because a
    /// zero transaction ID on an incoming message marks an event.
    next_tx_id: u32,
    /// Handles to be closed with a batched `CLOSE` request by
    /// `process_waiting_to_close`.
    waiting_to_close: Vec<proto::HandleId>,
    /// Waker of the task that last polled the transport. NOTE(review):
    /// presumably woken by whoever pushes onto `waiting_to_close` so the
    /// close gets flushed; confirm at the call sites.
    waiting_to_close_waker: Waker,

    /// There is a lock around `ClientInner`, and sometimes the FIDL bindings
    /// give us wakers that want to do handle operations synchronously on wake,
    /// which means we can double-take the lock if we wake a waker while we hold
    /// it. This is a place to store wakers that we'd like to be woken as soon
    /// as we're not holding that lock, to avoid these weird reentrancy issues.
    wakers_to_wake: Vec<Waker>,
}
439
440impl ClientInner {
441    /// Serialize and enqueue a new transaction, including header and transaction ID.
442    fn request<S: fidl_message::Body>(&mut self, ordinal: u64, request: S, responder: Responder) {
443        if ordinal != ordinals::CLOSE {
444            self.process_waiting_to_close();
445        }
446        let tx_id = self.next_tx_id;
447
448        let header = TransactionHeader::new(tx_id, ordinal, fidl_message::DynamicFlags::FLEXIBLE);
449        let msg = fidl_message::encode_message(header, request).expect("Could not encode request!");
450        self.next_tx_id += 1;
451        if let Err(e) = self.transport.push_msg(msg.into()) {
452            let _ = responder.handle(self, Err(e.into()));
453        } else {
454            assert!(
455                self.transactions.insert(tx_id.try_into().unwrap(), responder).is_none(),
456                "Allocated same tx id twice!"
457            );
458        }
459    }
460
461    fn process_waiting_to_close(&mut self) {
462        if !self.waiting_to_close.is_empty() {
463            let handles = std::mem::replace(&mut self.waiting_to_close, Vec::new());
464            // We've dropped the handle object. Nobody is going to wait to read
465            // the buffers anymore. This is a safe time to drop the read state.
466            for handle in &handles {
467                let _ = self.channel_read_states.remove(handle);
468                let _ = self.socket_read_states.remove(handle);
469            }
470            self.request(
471                ordinals::CLOSE,
472                proto::FDomainCloseRequest { handles },
473                Responder::Ignore,
474            );
475        }
476    }
477
478    /// Polls the underlying transport to ensure any incoming or outgoing
479    /// messages are processed as far as possible. Errors if the transport has failed.
480    fn try_poll_transport(
481        &mut self,
482        ctx: &mut Context<'_>,
483    ) -> Poll<Result<Infallible, InnerError>> {
484        self.process_waiting_to_close();
485
486        self.waiting_to_close_waker = ctx.waker().clone();
487
488        loop {
489            if let Poll::Ready(e) = self.transport.poll_send_messages(ctx) {
490                return Poll::Ready(Err(e));
491            }
492            let Poll::Ready(result) = self.transport.poll_next(ctx) else {
493                return Poll::Pending;
494            };
495            let data = result?;
496            let (header, data) = fidl_message::decode_transaction_header(&data)?;
497
498            let Some(tx_id) = NonZeroU32::new(header.tx_id) else {
499                let wakers = self.process_event(header, data)?;
500                self.wakers_to_wake.extend(wakers);
501                continue;
502            };
503
504            let tx = self.transactions.remove(&tx_id).ok_or(::fidl::Error::InvalidResponseTxid)?;
505            tx.handle(self, Ok((header, data)))?;
506        }
507    }
508
509    /// Process an incoming message that arose from an event rather than a transaction reply.
510    fn process_event(
511        &mut self,
512        header: TransactionHeader,
513        data: &[u8],
514    ) -> Result<Vec<Waker>, InnerError> {
515        match header.ordinal {
516            ordinals::ON_SOCKET_STREAMING_DATA => {
517                let msg = fidl_message::decode_message::<proto::SocketOnSocketStreamingDataRequest>(
518                    header, data,
519                )?;
520                let o =
521                    self.socket_read_states.entry(msg.handle).or_insert_with(|| SocketReadState {
522                        wakers: Vec::new(),
523                        queued: VecDeque::new(),
524                        is_streaming: false,
525                        read_request_pending: false,
526                    });
527                match msg.socket_message {
528                    proto::SocketMessage::Data(data) => Ok(o.handle_incoming_message(Ok(data))),
529                    proto::SocketMessage::Stopped(proto::AioStopped { error }) => {
530                        let ret = if let Some(error) = error {
531                            o.handle_incoming_message(Err(Error::FDomain(*error)))
532                        } else {
533                            Vec::new()
534                        };
535                        o.is_streaming = false;
536                        Ok(ret)
537                    }
538                    _ => Err(InnerError::ProtocolStreamEventIncompatible),
539                }
540            }
541            ordinals::ON_CHANNEL_STREAMING_DATA => {
542                let msg = fidl_message::decode_message::<
543                    proto::ChannelOnChannelStreamingDataRequest,
544                >(header, data)?;
545                let o = self.channel_read_states.entry(msg.handle).or_insert_with(|| {
546                    ChannelReadState {
547                        wakers: Vec::new(),
548                        queued: VecDeque::new(),
549                        is_streaming: false,
550                        read_request_pending: false,
551                    }
552                });
553                match msg.channel_sent {
554                    proto::ChannelSent::Message(data) => Ok(o.handle_incoming_message(Ok(data))),
555                    proto::ChannelSent::Stopped(proto::AioStopped { error }) => {
556                        let ret = if let Some(error) = error {
557                            o.handle_incoming_message(Err(Error::FDomain(*error)))
558                        } else {
559                            Vec::new()
560                        };
561                        o.is_streaming = false;
562                        Ok(ret)
563                    }
564                    _ => Err(InnerError::ProtocolStreamEventIncompatible),
565                }
566            }
567            _ => Err(::fidl::Error::UnknownOrdinal {
568                ordinal: header.ordinal,
569                protocol_name:
570                    <proto::FDomainMarker as ::fidl::endpoints::ProtocolMarker>::DEBUG_NAME,
571            }
572            .into()),
573        }
574    }
575
576    /// Polls the underlying transport to ensure any incoming or outgoing
577    /// messages are processed as far as possible. If a failure occurs, puts the
578    /// transport into an error state and fails all pending transactions.
579    fn poll_transport(&mut self, ctx: &mut Context<'_>) -> Poll<()> {
580        if let Poll::Ready(Err(e)) = self.try_poll_transport(ctx) {
581            for (_, v) in std::mem::take(&mut self.transactions) {
582                let _ = v.handle(self, Err(e.clone()));
583            }
584            for mut state in std::mem::take(&mut self.socket_read_states).into_values() {
585                state.queued.push_back(Err(Error::from(e.clone())));
586                self.wakers_to_wake.extend(state.wakers);
587            }
588            for (_, mut state) in self.channel_read_states.drain() {
589                state.queued.push_back(Err(Error::from(e.clone())));
590                self.wakers_to_wake.extend(state.wakers);
591            }
592            if matches!(self.transport, Transport::Transport(_, _, _)) {
593                self.transport = Transport::Error(e);
594            }
595
596            Poll::Ready(())
597        } else {
598            Poll::Pending
599        }
600    }
601
602    /// Handles the response to a `SocketRead` protocol message.
603    pub(crate) fn handle_socket_read_response(
604        &mut self,
605        msg: Result<proto::SocketData, Error>,
606        id: proto::HandleId,
607    ) {
608        let state = self.socket_read_states.entry(id).or_insert_with(|| SocketReadState {
609            wakers: Vec::new(),
610            queued: VecDeque::new(),
611            is_streaming: false,
612            read_request_pending: false,
613        });
614        let wakers = state.handle_incoming_message(msg);
615        self.wakers_to_wake.extend(wakers);
616        state.read_request_pending = false;
617    }
618
619    /// Handles the response to a `ChannelRead` protocol message.
620    pub(crate) fn handle_channel_read_response(
621        &mut self,
622        msg: Result<proto::ChannelMessage, Error>,
623        id: proto::HandleId,
624    ) {
625        let state = self.channel_read_states.entry(id).or_insert_with(|| ChannelReadState {
626            wakers: Vec::new(),
627            queued: VecDeque::new(),
628            is_streaming: false,
629            read_request_pending: false,
630        });
631        let wakers = state.handle_incoming_message(msg);
632        self.wakers_to_wake.extend(wakers);
633        state.read_request_pending = false;
634    }
635}
636
637impl Drop for ClientInner {
638    fn drop(&mut self) {
639        let responders = self.transactions.drain().map(|x| x.1).collect::<Vec<_>>();
640        for responder in responders {
641            let _ = responder.handle(self, Err(InnerError::Transport(None)));
642        }
643        for state in self.channel_read_states.values_mut() {
644            state.wakers.drain(..).for_each(Waker::wake);
645        }
646        for state in self.socket_read_states.values_mut() {
647            state.wakers.drain(..).for_each(Waker::wake);
648        }
649        self.waiting_to_close_waker.wake_by_ref();
650        self.wakers_to_wake.drain(..).for_each(Waker::wake);
651    }
652}
653
654/// Represents a connection to an FDomain.
655///
656/// The client is constructed by passing it a transport object which represents
657/// the raw connection to the remote FDomain. The `Client` wrapper then allows
658/// us to construct and use handles which behave similarly to their counterparts
659/// on a Fuchsia device.
660pub struct Client(pub(crate) Mutex<ClientInner>);
661
662impl std::fmt::Debug for Client {
663    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664        let inner = self.0.lock();
665        match &inner.transport {
666            Transport::Transport(transport, ..) if transport.has_debug_fmt() => {
667                write!(f, "Client(")?;
668                transport.debug_fmt(f)?;
669                write!(f, ")")
670            }
671            Transport::Error(error) => {
672                let error = Error::from(error.clone());
673                write!(f, "Client(Failed: {error})")
674            }
675            _ => f.debug_tuple("Client").field(&"<transport>").finish(),
676        }
677    }
678}
679
680/// A client which is always disconnected. Handles that lose their clients
681/// connect to this client instead, which always returns a "Client Lost"
682/// transport failure.
683pub(crate) static DEAD_CLIENT: LazyLock<Arc<Client>> = LazyLock::new(|| {
684    Arc::new(Client(Mutex::new(ClientInner {
685        transport: Transport::Error(InnerError::Transport(None)),
686        transactions: HashMap::new(),
687        channel_read_states: HashMap::new(),
688        socket_read_states: HashMap::new(),
689        next_tx_id: 1,
690        waiting_to_close: Vec::new(),
691        waiting_to_close_waker: std::task::Waker::noop().clone(),
692        wakers_to_wake: Vec::new(),
693    })))
694});
695
696/// A wrapper around the FDomain client background future that ensures
697/// all pending transactions and reads are failed if the loop is dropped.
698///
699/// This prevents hangs when the transport is abruptly closed (e.g. during target reboot)
700/// by waking up any futures waiting for responses or data on channels/sockets.
701pub struct ClientLoop {
702    client: Weak<Client>,
703    fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
704}
705
706impl Future for ClientLoop {
707    type Output = ();
708    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
709        self.fut.as_mut().poll(cx)
710    }
711}
712
713impl Drop for ClientLoop {
714    fn drop(&mut self) {
715        let Some(client) = self.client.upgrade() else {
716            return;
717        };
718
719        let (channel_read_states, socket_read_states, deferred_wakers) = {
720            let mut inner = client.0.lock();
721            let transactions = std::mem::take(&mut inner.transactions);
722            log::debug!("ClientLoop dropped, failing {} transactions", transactions.len());
723            for (_, v) in transactions {
724                let _ = v.handle(&mut *inner, Err(InnerError::Transport(None)));
725            }
726
727            let channel_read_states = std::mem::take(&mut inner.channel_read_states);
728            let socket_read_states = std::mem::take(&mut inner.socket_read_states);
729
730            let deferred_wakers = std::mem::replace(&mut inner.wakers_to_wake, Vec::new());
731
732            (channel_read_states, socket_read_states, deferred_wakers)
733        };
734
735        log::debug!("Failing reads on {} channels", channel_read_states.len());
736        for (_, mut state) in channel_read_states {
737            state.queued.push_back(Err(Error::Transport(None)));
738            state.wakers.into_iter().for_each(Waker::wake);
739        }
740
741        log::debug!("Failing reads on {} sockets", socket_read_states.len());
742        for (_, mut state) in socket_read_states {
743            state.queued.push_back(Err(Error::Transport(None)));
744            state.wakers.into_iter().for_each(Waker::wake);
745        }
746
747        deferred_wakers.into_iter().for_each(Waker::wake);
748    }
749}
750
751impl Client {
752    pub fn transport_status(&self) -> Result<()> {
753        match &self.0.lock().transport {
754            Transport::Error(e) => Err(e.clone().into()),
755            Transport::Transport(_, _, _) => Ok(()),
756        }
757    }
758    /// Create a new FDomain client. The `transport` argument should contain the
759    /// established connection to the target, ready to communicate the FDomain
760    /// protocol.
761    ///
762    /// The second return item is a future that must be polled to keep
763    /// transactions running.
764    pub fn new(
765        transport: impl FDomainTransport + 'static,
766    ) -> (Arc<Self>, impl Future<Output = ()> + Send + 'static) {
767        let ret = Arc::new(Client(Mutex::new(ClientInner {
768            transport: Transport::Transport(Box::pin(transport), VecDeque::new(), Vec::new()),
769            transactions: HashMap::new(),
770            socket_read_states: HashMap::new(),
771            channel_read_states: HashMap::new(),
772            next_tx_id: 1,
773            waiting_to_close: Vec::new(),
774            waiting_to_close_waker: std::task::Waker::noop().clone(),
775            wakers_to_wake: Vec::new(),
776        })));
777
778        let client_weak = Arc::downgrade(&ret);
779        let fut = futures::future::poll_fn(move |ctx| {
780            let Some(client) = client_weak.upgrade() else {
781                return Poll::Ready(());
782            };
783
784            let (ret, deferred_wakers) = {
785                let mut inner = client.0.lock();
786                let ret = inner.poll_transport(ctx);
787                let deferred_wakers = std::mem::replace(&mut inner.wakers_to_wake, Vec::new());
788                (ret, deferred_wakers)
789            };
790            deferred_wakers.into_iter().for_each(Waker::wake);
791            ret
792        });
793
794        let client_loop = ClientLoop { client: Arc::downgrade(&ret), fut: Box::pin(fut) };
795
796        (ret, client_loop)
797    }
798
799    /// Get the namespace for the connected FDomain. Calling this more than once is an error.
800    pub async fn namespace(self: &Arc<Self>) -> Result<Channel, Error> {
801        let new_handle = self.new_hid();
802        self.transaction(
803            ordinals::GET_NAMESPACE,
804            proto::FDomainGetNamespaceRequest { new_handle },
805            Responder::Namespace,
806        )
807        .await?;
808        Ok(Channel(Handle { id: new_handle.id, client: Arc::downgrade(self) }))
809    }
810
811    /// Create a new channel in the connected FDomain.
812    pub fn create_channel(self: &Arc<Self>) -> (Channel, Channel) {
813        let id_a = self.new_hid();
814        let id_b = self.new_hid();
815        let fut = self.transaction(
816            ordinals::CREATE_CHANNEL,
817            proto::ChannelCreateChannelRequest { handles: [id_a, id_b] },
818            Responder::CreateChannel,
819        );
820
821        fuchsia_async::Task::spawn(async move {
822            if let Err(e) = fut.await {
823                log::debug!("FDomain channel creation failed: {e}");
824            }
825        })
826        .detach();
827
828        (
829            Channel(Handle { id: id_a.id, client: Arc::downgrade(self) }),
830            Channel(Handle { id: id_b.id, client: Arc::downgrade(self) }),
831        )
832    }
833
834    /// Creates client and server endpoints connected to by a channel.
835    pub fn create_endpoints<F: crate::fidl::ProtocolMarker>(
836        self: &Arc<Self>,
837    ) -> (crate::fidl::ClientEnd<F>, crate::fidl::ServerEnd<F>) {
838        let (client, server) = self.create_channel();
839        let client_end = crate::fidl::ClientEnd::<F>::new(client);
840        let server_end = crate::fidl::ServerEnd::new(server);
841        (client_end, server_end)
842    }
843
844    /// Creates a client proxy and a server endpoint connected by a channel.
845    pub fn create_proxy<F: crate::fidl::ProtocolMarker>(
846        self: &Arc<Self>,
847    ) -> (F::Proxy, crate::fidl::ServerEnd<F>) {
848        let (client_end, server_end) = self.create_endpoints::<F>();
849        (client_end.into_proxy(), server_end)
850    }
851
852    /// Creates a client proxy and a server request stream connected by a channel.
853    pub fn create_proxy_and_stream<F: crate::fidl::ProtocolMarker>(
854        self: &Arc<Self>,
855    ) -> (F::Proxy, F::RequestStream) {
856        let (client_end, server_end) = self.create_endpoints::<F>();
857        (client_end.into_proxy(), server_end.into_stream())
858    }
859
860    /// Creates a client end and a server request stream connected by a channel.
861    pub fn create_request_stream<F: crate::fidl::ProtocolMarker>(
862        self: &Arc<Self>,
863    ) -> (crate::fidl::ClientEnd<F>, F::RequestStream) {
864        let (client_end, server_end) = self.create_endpoints::<F>();
865        (client_end, server_end.into_stream())
866    }
867
868    /// Create a new socket in the connected FDomain.
869    fn create_socket(self: &Arc<Self>, options: proto::SocketType) -> (Socket, Socket) {
870        let id_a = self.new_hid();
871        let id_b = self.new_hid();
872        let fut = self.transaction(
873            ordinals::CREATE_SOCKET,
874            proto::SocketCreateSocketRequest { handles: [id_a, id_b], options },
875            Responder::CreateSocket,
876        );
877
878        fuchsia_async::Task::spawn(async move {
879            if let Err(e) = fut.await {
880                log::debug!("FDomain socket creation failed: {e}");
881            }
882        })
883        .detach();
884
885        (
886            Socket(Handle { id: id_a.id, client: Arc::downgrade(self) }),
887            Socket(Handle { id: id_b.id, client: Arc::downgrade(self) }),
888        )
889    }
890
891    /// Create a new streaming socket in the connected FDomain.
892    pub fn create_stream_socket(self: &Arc<Self>) -> (Socket, Socket) {
893        self.create_socket(proto::SocketType::Stream)
894    }
895
896    /// Create a new datagram socket in the connected FDomain.
897    pub fn create_datagram_socket(self: &Arc<Self>) -> (Socket, Socket) {
898        self.create_socket(proto::SocketType::Datagram)
899    }
900
901    /// Create a new event pair in the connected FDomain.
902    pub fn create_event_pair(self: &Arc<Self>) -> (EventPair, EventPair) {
903        let id_a = self.new_hid();
904        let id_b = self.new_hid();
905        let fut = self.transaction(
906            ordinals::CREATE_EVENT_PAIR,
907            proto::EventPairCreateEventPairRequest { handles: [id_a, id_b] },
908            Responder::CreateEventPair,
909        );
910
911        fuchsia_async::Task::spawn(async move {
912            if let Err(e) = fut.await {
913                log::debug!("FDomain event pair creation failed: {e}");
914            }
915        })
916        .detach();
917
918        (
919            EventPair(Handle { id: id_a.id, client: Arc::downgrade(self) }),
920            EventPair(Handle { id: id_b.id, client: Arc::downgrade(self) }),
921        )
922    }
923
924    /// Create a new event handle in the connected FDomain.
925    pub fn create_event(self: &Arc<Self>) -> Event {
926        let id = self.new_hid();
927        let fut = self.transaction(
928            ordinals::CREATE_EVENT,
929            proto::EventCreateEventRequest { handle: id },
930            Responder::CreateEvent,
931        );
932
933        fuchsia_async::Task::spawn(async move {
934            if let Err(e) = fut.await {
935                log::debug!("FDomain event creation failed: {e}");
936            }
937        })
938        .detach();
939
940        Event(Handle { id: id.id, client: Arc::downgrade(self) })
941    }
942
943    /// Allocate a new HID, which should be suitable for use with the connected FDomain.
944    pub(crate) fn new_hid(&self) -> proto::NewHandleId {
945        // TODO: On the target side we have to keep a table of these which means
946        // we can automatically detect collisions in the random value. On the
947        // client side we'd have to add a whole data structure just for that
948        // purpose. Should we?
949        proto::NewHandleId { id: rand::random::<u32>() >> 1 }
950    }
951
952    /// Create a future which sends a FIDL message to the connected FDomain and
953    /// waits for a response.
954    ///
955    /// Calling this method queues the transaction synchronously. Awaiting is
956    /// only necessary to wait for the response.
957    pub(crate) fn transaction<S: fidl_message::Body, R: 'static, F>(
958        self: &Arc<Self>,
959        ordinal: u64,
960        request: S,
961        f: F,
962    ) -> impl Future<Output = Result<R, Error>> + 'static + use<S, R, F>
963    where
964        F: Fn(OneshotSender<Result<R, Error>>) -> Responder,
965    {
966        let mut inner = self.0.lock();
967
968        let (sender, receiver) = futures::channel::oneshot::channel();
969        inner.request(ordinal, request, f(sender));
970        receiver.map(|x| x.expect("Oneshot went away without reply!"))
971    }
972
973    /// Start getting streaming events for socket reads.
974    pub(crate) fn start_socket_streaming(&self, id: proto::HandleId) -> Result<(), Error> {
975        let mut inner = self.0.lock();
976        if let Some(e) = inner.transport.error() {
977            return Err(e.into());
978        }
979
980        let state = inner.socket_read_states.entry(id).or_insert_with(|| SocketReadState {
981            wakers: Vec::new(),
982            queued: VecDeque::new(),
983            is_streaming: false,
984            read_request_pending: false,
985        });
986
987        assert!(!state.is_streaming, "Initiated streaming twice!");
988        state.is_streaming = true;
989
990        inner.request(
991            ordinals::READ_SOCKET_STREAMING_START,
992            proto::SocketReadSocketStreamingStartRequest { handle: id },
993            Responder::Ignore,
994        );
995        Ok(())
996    }
997
998    /// Stop getting streaming events for socket reads. Doesn't return errors
999    /// because it's exclusively called in destructors where we have nothing to
1000    /// do with them.
1001    pub(crate) fn stop_socket_streaming(&self, id: proto::HandleId) {
1002        let mut inner = self.0.lock();
1003        if let Some(state) = inner.socket_read_states.get_mut(&id) {
1004            if state.is_streaming {
1005                state.is_streaming = false;
1006                // TODO: Log?
1007                let _ = inner.request(
1008                    ordinals::READ_SOCKET_STREAMING_STOP,
1009                    proto::ChannelReadChannelStreamingStopRequest { handle: id },
1010                    Responder::Ignore,
1011                );
1012            }
1013        }
1014    }
1015
1016    /// Start getting streaming events for socket reads.
1017    pub(crate) fn start_channel_streaming(&self, id: proto::HandleId) -> Result<(), Error> {
1018        let mut inner = self.0.lock();
1019        if let Some(e) = inner.transport.error() {
1020            return Err(e.into());
1021        }
1022        let state = inner.channel_read_states.entry(id).or_insert_with(|| ChannelReadState {
1023            wakers: Vec::new(),
1024            queued: VecDeque::new(),
1025            is_streaming: false,
1026            read_request_pending: false,
1027        });
1028
1029        assert!(!state.is_streaming, "Initiated streaming twice!");
1030        state.is_streaming = true;
1031
1032        inner.request(
1033            ordinals::READ_CHANNEL_STREAMING_START,
1034            proto::ChannelReadChannelStreamingStartRequest { handle: id },
1035            Responder::Ignore,
1036        );
1037
1038        Ok(())
1039    }
1040
1041    /// Stop getting streaming events for socket reads. Doesn't return errors
1042    /// because it's exclusively called in destructors where we have nothing to
1043    /// do with them.
1044    pub(crate) fn stop_channel_streaming(&self, id: proto::HandleId) {
1045        let mut inner = self.0.lock();
1046        if let Some(state) = inner.channel_read_states.get_mut(&id) {
1047            if state.is_streaming {
1048                state.is_streaming = false;
1049                // TODO: Log?
1050                let _ = inner.request(
1051                    ordinals::READ_CHANNEL_STREAMING_STOP,
1052                    proto::ChannelReadChannelStreamingStopRequest { handle: id },
1053                    Responder::Ignore,
1054                );
1055            }
1056        }
1057    }
1058
1059    /// Execute a read from a channel.
1060    pub(crate) fn poll_socket(
1061        &self,
1062        id: proto::HandleId,
1063        ctx: &mut Context<'_>,
1064        out: &mut [u8],
1065    ) -> Poll<Result<usize, Error>> {
1066        let mut inner = self.0.lock();
1067        if let Some(error) = inner.transport.error() {
1068            return Poll::Ready(Err(error.into()));
1069        }
1070
1071        let state = inner.socket_read_states.entry(id).or_insert_with(|| SocketReadState {
1072            wakers: Vec::new(),
1073            queued: VecDeque::new(),
1074            is_streaming: false,
1075            read_request_pending: false,
1076        });
1077
1078        if let Some(got) = state.queued.front_mut() {
1079            match got.as_mut() {
1080                Ok(data) => {
1081                    let read_size = std::cmp::min(data.data.len(), out.len());
1082                    out[..read_size].copy_from_slice(&data.data[..read_size]);
1083
1084                    if data.data.len() > read_size && !data.is_datagram {
1085                        let _ = data.data.drain(..read_size);
1086                    } else {
1087                        let _ = state.queued.pop_front();
1088                    }
1089
1090                    return Poll::Ready(Ok(read_size));
1091                }
1092                Err(_) => {
1093                    let err = state.queued.pop_front().unwrap().unwrap_err();
1094                    return Poll::Ready(Err(err));
1095                }
1096            }
1097        } else if !state.wakers.iter().any(|x| ctx.waker().will_wake(x)) {
1098            state.wakers.push(ctx.waker().clone());
1099        }
1100
1101        if !state.read_request_pending && !state.is_streaming {
1102            inner.request(
1103                ordinals::READ_SOCKET,
1104                proto::SocketReadSocketRequest { handle: id, max_bytes: out.len() as u64 },
1105                Responder::ReadSocket(id),
1106            );
1107        }
1108
1109        Poll::Pending
1110    }
1111
1112    /// Execute a read from a channel.
1113    pub(crate) fn poll_channel(
1114        &self,
1115        id: proto::HandleId,
1116        ctx: &mut Context<'_>,
1117        for_stream: bool,
1118    ) -> Poll<Option<Result<proto::ChannelMessage, Error>>> {
1119        let mut inner = self.0.lock();
1120        if let Some(error) = inner.transport.error() {
1121            return Poll::Ready(Some(Err(error.into())));
1122        }
1123
1124        let state = inner.channel_read_states.entry(id).or_insert_with(|| ChannelReadState {
1125            wakers: Vec::new(),
1126            queued: VecDeque::new(),
1127            is_streaming: false,
1128            read_request_pending: false,
1129        });
1130
1131        if let Some(got) = state.queued.pop_front() {
1132            return Poll::Ready(Some(got));
1133        } else if for_stream && !state.is_streaming {
1134            return Poll::Ready(None);
1135        } else if !state.wakers.iter().any(|x| ctx.waker().will_wake(x)) {
1136            state.wakers.push(ctx.waker().clone());
1137        }
1138
1139        if !state.read_request_pending && !state.is_streaming {
1140            inner.request(
1141                ordinals::READ_CHANNEL,
1142                proto::ChannelReadChannelRequest { handle: id },
1143                Responder::ReadChannel(id),
1144            );
1145        }
1146
1147        Poll::Pending
1148    }
1149
1150    /// Check whether this channel is streaming
1151    pub(crate) fn channel_is_streaming(&self, id: proto::HandleId) -> bool {
1152        let inner = self.0.lock();
1153        let Some(state) = inner.channel_read_states.get(&id) else {
1154            return false;
1155        };
1156        state.is_streaming
1157    }
1158
1159    /// Check that all the given handles are safe to transfer through a channel
1160    /// e.g. that there's no chance of in-flight reads getting dropped.
1161    pub(crate) fn clear_handles_for_transfer(&self, handles: &proto::Handles) {
1162        let inner = self.0.lock();
1163        match handles {
1164            proto::Handles::Handles(handles) => {
1165                for handle in handles {
1166                    assert!(
1167                        !(inner.channel_read_states.contains_key(handle)
1168                            || inner.socket_read_states.contains_key(handle)),
1169                        "Tried to transfer handle after reading"
1170                    );
1171                }
1172            }
1173            proto::Handles::Dispositions(dispositions) => {
1174                for disposition in dispositions {
1175                    match &disposition.handle {
1176                        proto::HandleOp::Move_(handle) => assert!(
1177                            !(inner.channel_read_states.contains_key(handle)
1178                                || inner.socket_read_states.contains_key(handle)),
1179                            "Tried to transfer handle after reading"
1180                        ),
1181                        // Pretty sure this should be fine regardless of read state.
1182                        proto::HandleOp::Duplicate(_) => (),
1183                    }
1184                }
1185            }
1186        }
1187    }
1188}