fidl_contrib/
protocol_connector.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Provides libs for connecting to and interacting with a FIDL protocol.
6//!
7//! If you have a fidl protocol like this:
8//!
9//! ```fidl
10//! type Error = strict enum : int32 {
11//!     PERMANENT = 1;
12//!     TRANSIENT = 2;
13//! };
14//!
15//! @discoverable
16//! protocol ProtocolFactory {
17//!     CreateProtocol(resource struct {
18//!         protocol server_end:Protocol;
19//!     }) -> () error Error;
20//! };
21//!
22//! protocol Protocol {
23//!     DoAction() -> () error Error;
24//! };
25//! ```
26//!
27//! Then you could implement ConnectedProtocol as follows:
28//!
29//! ```rust
30//! struct ProtocolConnectedProtocol;
31//! impl ConnectedProtocol for ProtocolConnectedProtocol {
32//!     type Protocol = ProtocolProxy;
33//!     type ConnectError = anyhow::Error;
34//!     type Message = ();
35//!     type SendError = anyhow::Error;
36//!
37//!     fn get_protocol<'a>(
38//!         &'a mut self,
39//!     ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
40//!         async move {
41//!             let (protocol_proxy, server_end) =
42//!                 fidl::endpoints::create_proxy();
43//!             let protocol_factory = connect_to_protocol::<ProtocolFactoryMarker>()
44//!                 .context("Failed to connect to test.protocol.ProtocolFactory")?;
45//!
46//!             protocol_factory
47//!                 .create_protocol(server_end)
48//!                 .await?
49//!                 .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;
50//!
51//!             Ok(protocol_proxy)
52//!         }
53//!         .boxed()
54//!     }
55//!
56//!     fn send_message<'a>(
57//!         &'a mut self,
58//!         protocol: &'a Self::Protocol,
59//!         _msg: (),
60//!     ) -> BoxFuture<'a, Result<(), Self::SendError>> {
61//!         async move {
62//!             protocol.do_action().await?.map_err(|e| format_err!("Failed to do action: {:?}", e))?;
63//!             Ok(())
64//!         }
65//!         .boxed()
66//!     }
67//! }
68//! ```
69//!
70//! Then all you would have to do to connect to the service is:
71//!
72//! ```rust
73//! let connector = ProtocolConnector::new(ProtocolConnectedProtocol);
74//! let (sender, future) = connector.serve_and_log_errors();
75//! let future = Task::spawn(future);
76//! // Use sender to send messages to the protocol
77//! ```
78
79use fuchsia_async::{self as fasync, DurationExt};
80
81use futures::channel::mpsc;
82use futures::future::BoxFuture;
83use futures::{Future, StreamExt};
84use log::error;
85use std::sync::atomic::{AtomicBool, Ordering};
86use std::sync::Arc;
87
88/// A trait for implementing connecting to and sending messages to a FIDL protocol.
89pub trait ConnectedProtocol {
90    /// The protocol that will be connected to.
91    type Protocol: fidl::endpoints::Proxy;
92
93    /// An error type returned for connection failures.
94    type ConnectError: std::fmt::Display;
95
96    /// The message type that will be forwarded to the `Protocol`.
97    type Message;
98
99    /// An error type returned for message send failures.
100    type SendError: std::fmt::Display;
101
102    /// Connects to the protocol represented by `Protocol`.
103    ///
104    /// If this is a two-step process as in the case of the ServiceHub pattern,
105    /// both steps should be performed in this function.
106    fn get_protocol<'a>(&'a mut self) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>>;
107
108    /// Sends a message to the underlying `Protocol`.
109    ///
110    /// The protocol object should be assumed to be connected.
111    fn send_message<'a>(
112        &'a mut self,
113        protocol: &'a Self::Protocol,
114        msg: Self::Message,
115    ) -> BoxFuture<'a, Result<(), Self::SendError>>;
116}
117
118/// A ProtocolSender wraps around an `mpsc::Sender` object that is used to send
119/// messages to a running ProtocolConnector instance.
120#[derive(Clone, Debug)]
121pub struct ProtocolSender<Msg> {
122    sender: mpsc::Sender<Msg>,
123    is_blocked: Arc<AtomicBool>,
124}
125
/// Outcome of `ProtocolSender::send`, describing the state of the underlying `mpsc::channel`.
///
/// Every variant is purely informational; none of them should be treated as an error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProtocolSenderStatus {
    /// The channel accepted the message and was not previously blocked.
    Healthy,

    /// The channel rejected a message for the first time.
    BackoffStarts,

    /// The channel is still rejecting messages.
    InBackoff,

    /// The channel accepted a message again after having rejected some.
    BackoffEnds,
}
142
143impl<Msg> ProtocolSender<Msg> {
144    /// Create a new ProtocolSender which will use `sender` to send messages.
145    pub fn new(sender: mpsc::Sender<Msg>) -> Self {
146        Self { sender, is_blocked: Arc::new(AtomicBool::new(false)) }
147    }
148
149    /// Send a message to the underlying channel.
150    ///
151    /// When the sender enters or exits a backoff state, it will log an error,
152    /// but no other feedback will be provided to the caller.
153    pub fn send(&mut self, message: Msg) -> ProtocolSenderStatus {
154        if self.sender.try_send(message).is_err() {
155            let was_blocked =
156                self.is_blocked.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
157            if let Ok(false) = was_blocked {
158                ProtocolSenderStatus::BackoffStarts
159            } else {
160                ProtocolSenderStatus::InBackoff
161            }
162        } else {
163            let was_blocked =
164                self.is_blocked.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
165            if let Ok(true) = was_blocked {
166                ProtocolSenderStatus::BackoffEnds
167            } else {
168                ProtocolSenderStatus::Healthy
169            }
170        }
171    }
172}
173
174struct ExponentialBackoff {
175    initial: zx::MonotonicDuration,
176    current: zx::MonotonicDuration,
177    factor: f64,
178}
179
180impl ExponentialBackoff {
181    fn new(initial: zx::MonotonicDuration, factor: f64) -> Self {
182        Self { initial, current: initial, factor }
183    }
184
185    fn next_timer(&mut self) -> fasync::Timer {
186        let timer = fasync::Timer::new(self.current.after_now());
187        self.current = zx::MonotonicDuration::from_nanos(
188            (self.current.into_nanos() as f64 * self.factor) as i64,
189        );
190        timer
191    }
192
193    fn reset(&mut self) {
194        self.current = self.initial;
195    }
196}
197
/// Failures reported while connecting to, or sending messages through, the
/// ConnectedProtocol implementation.
#[derive(Debug, Eq, PartialEq)]
pub enum ProtocolConnectorError<ConnectError, ProtocolError> {
    /// The connection to the protocol could not be established.
    ConnectFailed(ConnectError),

    /// The connection to the protocol was dropped. A reconnect will be triggered.
    ConnectionLost,

    /// Sending a message produced an error while the connection stayed open.
    ProtocolError(ProtocolError),
}
210/// ProtocolConnector contains the logic to use a `ConnectedProtocol` to connect
211/// to and forward messages to a protocol.
212pub struct ProtocolConnector<CP: ConnectedProtocol> {
213    /// The size of the `mpsc::channel` to use when sending event objects from the main thread to the worker thread.
214    pub buffer_size: usize,
215    protocol: CP,
216}
217
218impl<CP: ConnectedProtocol> ProtocolConnector<CP> {
219    /// Construct a ProtocolConnector with the default `buffer_size` (10)
220    pub fn new(protocol: CP) -> Self {
221        Self::new_with_buffer_size(protocol, 10)
222    }
223
224    /// Construct a ProtocolConnector with a specified `buffer_size`
225    pub fn new_with_buffer_size(protocol: CP, buffer_size: usize) -> Self {
226        Self { buffer_size, protocol }
227    }
228
229    /// serve_and_log_errors creates both a ProtocolSender and a future that can
230    /// be used to send messages to the underlying protocol. All errors from the
231    /// underlying protocol will be logged.
232    pub fn serve_and_log_errors(self) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
233        let protocol = <<<CP as ConnectedProtocol>::Protocol as fidl::endpoints::Proxy>::Protocol as fidl::endpoints::ProtocolMarker>::DEBUG_NAME;
234        let mut log_error = log_first_n_factory(30, move |e| error!(protocol:%; "{e}"));
235        self.serve(move |e| match e {
236            ProtocolConnectorError::ConnectFailed(e) => {
237                log_error(format!("Error obtaining a connection to the protocol: {}", e))
238            }
239            ProtocolConnectorError::ConnectionLost => {
240                log_error("Protocol disconnected, starting reconnect.".into())
241            }
242            ProtocolConnectorError::ProtocolError(e) => {
243                log_error(format!("Protocol returned an error: {}", e))
244            }
245        })
246    }
247
248    /// serve creates both a ProtocolSender and a future that can be used to send
249    /// messages to the underlying protocol.
250    pub fn serve<ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>)>(
251        self,
252        h: ErrHandler,
253    ) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
254        let (sender, receiver) = mpsc::channel(self.buffer_size);
255        let sender = ProtocolSender::new(sender);
256        (sender, self.send_events(receiver, h))
257    }
258
259    async fn send_events<
260        ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>),
261    >(
262        mut self,
263        mut receiver: mpsc::Receiver<<CP as ConnectedProtocol>::Message>,
264        mut h: ErrHandler,
265    ) {
266        let mut backoff = ExponentialBackoff::new(zx::MonotonicDuration::from_millis(100), 2.0);
267        loop {
268            let protocol = match self.protocol.get_protocol().await {
269                Ok(protocol) => protocol,
270                Err(e) => {
271                    h(ProtocolConnectorError::ConnectFailed(e));
272                    backoff.next_timer().await;
273                    continue;
274                }
275            };
276
277            'receiving: loop {
278                match receiver.next().await {
279                    Some(message) => {
280                        let resp = self.protocol.send_message(&protocol, message).await;
281                        match resp {
282                            Ok(_) => {
283                                backoff.reset();
284                                continue;
285                            }
286                            Err(e) => {
287                                if fidl::endpoints::Proxy::is_closed(&protocol) {
288                                    h(ProtocolConnectorError::ConnectionLost);
289                                    break 'receiving;
290                                } else {
291                                    h(ProtocolConnectorError::ProtocolError(e));
292                                }
293                            }
294                        }
295                    }
296                    None => return,
297                }
298            }
299
300            backoff.next_timer().await;
301        }
302    }
303}
304
/// Wraps `log_fn` in a closure that forwards only the first `n` messages it is
/// given; every later message is discarded.
fn log_first_n_factory(n: u64, mut log_fn: impl FnMut(String)) -> impl FnMut(String) {
    let mut forwarded: u64 = 0;
    move |message| {
        if forwarded >= n {
            // Budget exhausted: drop the message.
            return;
        }
        forwarded += 1;
        log_fn(message);
    }
}
314
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::{format_err, Context};
    use fidl_test_protocol_connector::{
        ProtocolFactoryMarker, ProtocolFactoryRequest, ProtocolFactoryRequestStream, ProtocolProxy,
        ProtocolRequest, ProtocolRequestStream,
    };
    use fuchsia_async as fasync;
    use fuchsia_component::server as fserver;
    use fuchsia_component_test::{
        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
    };
    use futures::channel::mpsc::Sender;
    use futures::{FutureExt, TryStreamExt};
    use std::sync::atomic::AtomicU8;

    /// `ConnectedProtocol` under test.
    ///
    /// Field 0 is the realm whose exposed directory offers `ProtocolFactory`;
    /// field 1 is notified after every successful `DoAction` call so the test
    /// can wait for a message to have been delivered.
    struct ProtocolConnectedProtocol(RealmInstance, Sender<()>);
    impl ConnectedProtocol for ProtocolConnectedProtocol {
        type Protocol = ProtocolProxy;
        type ConnectError = anyhow::Error;
        type Message = ();
        type SendError = anyhow::Error;

        /// Connects to `ProtocolFactory` in the realm and asks it to serve a
        /// new `Protocol` channel.
        fn get_protocol<'a>(
            &'a mut self,
        ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
            async move {
                let (protocol_proxy, server_end) = fidl::endpoints::create_proxy();
                let protocol_factory = self
                    .0
                    .root
                    .connect_to_protocol_at_exposed_dir::<ProtocolFactoryMarker>()
                    .context("Connecting to test.protocol.ProtocolFactory failed")?;

                protocol_factory
                    .create_protocol(server_end)
                    .await?
                    .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;

                Ok(protocol_proxy)
            }
            .boxed()
        }

        /// Calls `DoAction` and, on success, signals the test through field 1.
        fn send_message<'a>(
            &'a mut self,
            protocol: &'a Self::Protocol,
            _msg: (),
        ) -> BoxFuture<'a, Result<(), Self::SendError>> {
            async move {
                protocol
                    .do_action()
                    .await?
                    .map_err(|e| format_err!("Failed to do action: {:?}", e))?;
                self.1.try_send(())?;
                Ok(())
            }
            .boxed()
        }
    }

    /// Serves one `Protocol` connection.
    ///
    /// Every `DoAction` request increments `calls_made` and is answered with
    /// `Ok(())`. When `close_after` is set, it is decremented after each
    /// request, and serving stops with an error once the value read before the
    /// decrement was 1.
    async fn protocol_mock(
        stream: ProtocolRequestStream,
        calls_made: Arc<AtomicU8>,
        close_after: Option<Arc<AtomicU8>>,
    ) -> Result<(), anyhow::Error> {
        stream
            .map(|result| result.context("failed request"))
            .try_for_each(|request| async {
                let calls_made = calls_made.clone();
                let close_after = close_after.clone();
                match request {
                    ProtocolRequest::DoAction { responder } => {
                        calls_made.fetch_add(1, Ordering::SeqCst);
                        responder.send(Ok(()))?;
                    }
                }

                if let Some(ca) = &close_after {
                    if ca.fetch_sub(1, Ordering::SeqCst) == 1 {
                        return Err(format_err!("close_after triggered"));
                    }
                }
                Ok(())
            })
            .await
    }

    /// Local component that serves `ProtocolFactory` from its outgoing `svc`
    /// directory.
    ///
    /// Each `CreateProtocol` request is acknowledged and then served by
    /// `protocol_mock`, with a fresh `close_after` counter per created
    /// protocol.
    async fn protocol_factory_mock(
        handles: LocalComponentHandles,
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<(), anyhow::Error> {
        let mut fs = fserver::ServiceFs::new();
        // Holds the per-connection tasks so they are not dropped while `fs` is served.
        let mut tasks = vec![];

        fs.dir("svc").add_fidl_service(move |mut stream: ProtocolFactoryRequestStream| {
            let calls_made = calls_made.clone();
            tasks.push(fasync::Task::local(async move {
                while let Some(ProtocolFactoryRequest::CreateProtocol { protocol, responder }) =
                    stream.try_next().await.expect("ProtocolFactoryRequestStream yielded an Err(_)")
                {
                    let close_after = close_after.map(|ca| Arc::new(AtomicU8::new(ca)));
                    responder.send(Ok(())).expect("Replying to CreateProtocol caller failed");
                    // The mock returns an error when `close_after` triggers; that is
                    // the expected way for it to end, so the result is ignored.
                    let _ = protocol_mock(protocol.into_stream(), calls_made.clone(), close_after)
                        .await;
                }
            }));
        });

        fs.serve_connection(handles.outgoing_dir)?;
        fs.collect::<()>().await;

        Ok(())
    }

    /// Builds a realm containing the `protocol_factory` local child and routes
    /// `test.protocol.connector.ProtocolFactory` from it to the parent.
    async fn setup_realm(
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<RealmInstance, anyhow::Error> {
        let builder = RealmBuilder::new().await?;

        let protocol_factory_server = builder
            .add_local_child(
                "protocol_factory",
                move |handles: LocalComponentHandles| {
                    Box::pin(protocol_factory_mock(handles, calls_made.clone(), close_after))
                },
                ChildOptions::new(),
            )
            .await?;

        builder
            .add_route(
                Route::new()
                    .capability(Capability::protocol_by_name(
                        "test.protocol.connector.ProtocolFactory",
                    ))
                    .from(&protocol_factory_server)
                    .to(Ref::parent()),
            )
            .await?;

        Ok(builder.build().await?)
    }

    /// Ten messages sent over a connection that never closes must all reach
    /// the protocol without any error being reported.
    #[fuchsia::test(logging_tags = ["test_protocol_connector"])]
    async fn test_protocol_connector() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));
        let realm = setup_realm(calls_made.clone(), None).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let error_count = Arc::new(AtomicU8::new(0));
        let svc = ProtocolConnector::new(connector);
        let (mut sender, fut) = svc.serve({
            let count = error_count.clone();
            move |e| {
                error!("Encountered unexpected error: {:?}", e);
                count.fetch_add(1, Ordering::SeqCst);
            }
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            // Wait until the message has actually been delivered.
            log_received_receiver.next().await;
        }

        assert_eq!(calls_made.load(Ordering::SeqCst), 10);
        assert_eq!(error_count.load(Ordering::SeqCst), 0);

        Ok(())
    }

    /// With a protocol that closes after every call, each second message must
    /// be reported as `ConnectionLost` and the connector must reconnect.
    #[fuchsia::test(logging_tags = ["test_protocol_reconnect"])]
    async fn test_protocol_reconnect() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));

        // Simulate the protocol closing after each successful call.
        let realm = setup_realm(calls_made.clone(), Some(1)).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let svc = ProtocolConnector::new(connector);
        let (mut err_send, mut err_rcv) = mpsc::channel(1);
        let (mut sender, fut) = svc.serve(move |e| {
            err_send.try_send(e).expect("Could not log error");
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            // This first send will successfully call the underlying protocol.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;

            // The second send will not, because the protocol has shut down.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            let err = err_rcv.next().await.expect("Expected err");
            assert!(
                matches!(err, ProtocolConnectorError::ConnectionLost),
                "saw unexpected error type: {:?}",
                err
            );
        }

        assert_eq!(calls_made.load(Ordering::SeqCst), 10);

        Ok(())
    }
}