fidl_contrib/
protocol_connector.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Provides libs for connecting to and interacting with a FIDL protocol.
6//!
7//! If you have a fidl protocol like this:
8//!
9//! ```fidl
10//! type Error = strict enum : int32 {
11//!     PERMANENT = 1;
12//!     TRANSIENT = 2;
13//! };
14//!
15//! @discoverable
16//! protocol ProtocolFactory {
17//!     CreateProtocol(resource struct {
18//!         protocol server_end:Protocol;
19//!     }) -> () error Error;
20//! };
21//!
22//! protocol Protocol {
23//!     DoAction() -> () error Error;
24//! };
25//! ```
26//!
27//! Then you could implement ConnectedProtocol as follows:
28//!
29//! ```rust
30//! struct ProtocolConnectedProtocol;
31//! impl ConnectedProtocol for ProtocolConnectedProtocol {
32//!     type Protocol = ProtocolProxy;
33//!     type ConnectError = anyhow::Error;
34//!     type Message = ();
35//!     type SendError = anyhow::Error;
36//!
37//!     fn get_protocol<'a>(
38//!         &'a mut self,
39//!     ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
40//!         async move {
41//!             let (protocol_proxy, server_end) =
42//!                 fidl::endpoints::create_proxy();
43//!             let protocol_factory = connect_to_protocol::<ProtocolFactoryMarker>()
44//!                 .context("Failed to connect to test.protocol.ProtocolFactory")?;
45//!
46//!             protocol_factory
47//!                 .create_protocol(server_end)
48//!                 .await?
49//!                 .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;
50//!
51//!             Ok(protocol_proxy)
52//!         }
53//!         .boxed()
54//!     }
55//!
56//!     fn send_message<'a>(
57//!         &'a mut self,
58//!         protocol: &'a Self::Protocol,
59//!         _msg: (),
60//!     ) -> BoxFuture<'a, Result<(), Self::SendError>> {
61//!         async move {
62//!             protocol.do_action().await?.map_err(|e| format_err!("Failed to do action: {:?}", e))?;
63//!             Ok(())
64//!         }
65//!         .boxed()
66//!     }
67//! }
68//! ```
69//!
70//! Then all you would have to do to connect to the service is:
71//!
72//! ```rust
73//! let connector = ProtocolConnector::new(ProtocolConnectedProtocol);
74//! let (sender, future) = connector.serve_and_log_errors();
75//! let future = Task::spawn(future);
76//! // Use sender to send messages to the protocol
77//! ```
78
79use fuchsia_async::{self as fasync, DurationExt};
80
81use futures::channel::mpsc;
82use futures::future::BoxFuture;
83use futures::{Future, StreamExt};
84use log::error;
85use std::sync::Arc;
86use std::sync::atomic::{AtomicBool, Ordering};
87
/// A trait for implementing connecting to and sending messages to a FIDL protocol.
///
/// An implementation is handed to `ProtocolConnector`, which calls `get_protocol` to obtain a
/// connection and then calls `send_message` once for every message it receives from a
/// `ProtocolSender`.
pub trait ConnectedProtocol {
    /// The protocol that will be connected to.
    type Protocol: fidl::endpoints::Proxy;

    /// An error type returned for connection failures.
    type ConnectError: std::fmt::Display + 'static;

    /// The message type that will be forwarded to the `Protocol`.
    type Message;

    /// An error type returned for message send failures.
    type SendError: std::fmt::Display;

    /// Connects to the protocol represented by `Protocol`.
    ///
    /// If this is a two-step process as in the case of the ServiceHub pattern,
    /// both steps should be performed in this function.
    ///
    /// `ProtocolConnector` calls this again, after a backoff delay, whenever a previous attempt
    /// failed with a retryable error or the connection was found to be closed.
    fn get_protocol<'a>(&'a mut self) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>>;

    /// Sends a message to the underlying `Protocol`.
    ///
    /// The protocol object should be assumed to be connected.
    ///
    /// When this returns an error, `ProtocolConnector` checks whether `protocol` is closed to
    /// decide between reporting a lost connection (and reconnecting) or reporting the error
    /// itself. In both cases `msg` is not sent again.
    fn send_message<'a>(
        &'a mut self,
        protocol: &'a Self::Protocol,
        msg: Self::Message,
    ) -> BoxFuture<'a, Result<(), Self::SendError>>;

    /// Determines if a connection error is retryable. The default is to always retry.
    ///
    /// Returning `false` makes `ProtocolConnector` log the error and finish its serving future
    /// without attempting to connect again.
    fn should_retry_on_connect_error(&self, _error: &Self::ConnectError) -> bool {
        true
    }
}
122
/// A ProtocolSender wraps around an `mpsc::Sender` object that is used to send
/// messages to a running ProtocolConnector instance.
///
/// Clones share the same `is_blocked` flag (it lives behind an `Arc`), so backoff transitions
/// are tracked across all clones rather than per clone.
#[derive(Clone, Debug)]
pub struct ProtocolSender<Msg> {
    /// Sending half of the channel whose receiving half is drained by the connector's future.
    sender: mpsc::Sender<Msg>,
    /// `true` while the most recent `send` through any clone was rejected by the channel.
    is_blocked: Arc<AtomicBool>,
}
130
/// Returned by ProtocolSender::send to notify the caller about the state of the underlying
/// mpsc::channel.
///
/// None of these status codes should be considered an error state; they are purely
/// informational.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ProtocolSenderStatus {
    /// channel is accepting new messages.
    ///
    /// Returned when the message was accepted and the sender was not in backoff.
    Healthy,

    /// channel has rejected its first message.
    ///
    /// Returned when the message was rejected and the sender was not already in backoff.
    BackoffStarts,

    /// channel is not accepting new messages.
    ///
    /// Returned when the message was rejected and the sender was already in backoff.
    InBackoff,

    /// channel has begun accepting messages again.
    ///
    /// Returned when the message was accepted and the sender was in backoff until now.
    BackoffEnds,
}
147
148impl<Msg> ProtocolSender<Msg> {
149    /// Create a new ProtocolSender which will use `sender` to send messages.
150    pub fn new(sender: mpsc::Sender<Msg>) -> Self {
151        Self { sender, is_blocked: Arc::new(AtomicBool::new(false)) }
152    }
153
154    /// Send a message to the underlying channel.
155    ///
156    /// When the sender enters or exits a backoff state, it will log an error,
157    /// but no other feedback will be provided to the caller.
158    pub fn send(&mut self, message: Msg) -> ProtocolSenderStatus {
159        if self.sender.try_send(message).is_err() {
160            let was_blocked =
161                self.is_blocked.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
162            if let Ok(false) = was_blocked {
163                ProtocolSenderStatus::BackoffStarts
164            } else {
165                ProtocolSenderStatus::InBackoff
166            }
167        } else {
168            let was_blocked =
169                self.is_blocked.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
170            if let Ok(true) = was_blocked {
171                ProtocolSenderStatus::BackoffEnds
172            } else {
173                ProtocolSenderStatus::Healthy
174            }
175        }
176    }
177}
178
/// Tracks a delay that is multiplied by a fixed factor each time a timer is requested and can be
/// reset to its starting value.
struct ExponentialBackoff {
    /// Delay restored by `reset`.
    initial: zx::MonotonicDuration,
    /// Delay that the next call to `next_timer` will wait for.
    current: zx::MonotonicDuration,
    /// Multiplier applied to `current` after each call to `next_timer`.
    factor: f64,
}
184
185impl ExponentialBackoff {
186    fn new(initial: zx::MonotonicDuration, factor: f64) -> Self {
187        Self { initial, current: initial, factor }
188    }
189
190    fn next_timer(&mut self) -> fasync::Timer {
191        let timer = fasync::Timer::new(self.current.after_now());
192        self.current = zx::MonotonicDuration::from_nanos(
193            (self.current.into_nanos() as f64 * self.factor) as i64,
194        );
195        timer
196    }
197
198    fn reset(&mut self) {
199        self.current = self.initial;
200    }
201}
202
/// Errors encountered while connecting to or sending messages to the ConnectedProtocol implementation.
///
/// Values of this type are passed to the error handler given to `ProtocolConnector::serve`.
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolConnectorError<ConnectError, ProtocolError> {
    /// Connecting to the protocol failed for some reason.
    ///
    /// Carries the error returned by `ConnectedProtocol::get_protocol`.
    ConnectFailed(ConnectError),

    /// Connection to the protocol was dropped. A reconnect will be triggered.
    ///
    /// Reported when sending a message failed and the proxy is closed.
    ConnectionLost,

    /// The protocol returned an error while sending a message.
    ///
    /// Carries the error returned by `ConnectedProtocol::send_message`; reported when the proxy
    /// is not closed, in which case the existing connection keeps being used.
    ProtocolError(ProtocolError),
}
/// ProtocolConnector contains the logic to use a `ConnectedProtocol` to connect
/// to and forward messages to a protocol.
pub struct ProtocolConnector<CP: ConnectedProtocol> {
    /// The buffer size passed to `mpsc::channel` when creating the channel that carries messages
    /// from `ProtocolSender` handles to the serving future.
    pub buffer_size: usize,
    /// The user-supplied implementation used to connect and to send each message.
    protocol: CP,
}
222
223impl<CP: ConnectedProtocol> ProtocolConnector<CP> {
224    /// Construct a ProtocolConnector with the default `buffer_size` (10)
225    pub fn new(protocol: CP) -> Self {
226        Self::new_with_buffer_size(protocol, 10)
227    }
228
229    /// Construct a ProtocolConnector with a specified `buffer_size`
230    pub fn new_with_buffer_size(protocol: CP, buffer_size: usize) -> Self {
231        Self { buffer_size, protocol }
232    }
233
234    /// serve_and_log_errors creates both a ProtocolSender and a future that can
235    /// be used to send messages to the underlying protocol. All errors from the
236    /// underlying protocol will be logged.
237    pub fn serve_and_log_errors(self) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
238        let protocol = <<<CP as ConnectedProtocol>::Protocol as fidl::endpoints::Proxy>::Protocol as fidl::endpoints::ProtocolMarker>::DEBUG_NAME;
239        let mut log_error = log_first_n_factory(30, move |e| error!(protocol:%; "{e}"));
240        self.serve(move |e| match e {
241            ProtocolConnectorError::ConnectFailed(e) => {
242                log_error(format!("Error obtaining a connection to the protocol: {}", e))
243            }
244            ProtocolConnectorError::ConnectionLost => {
245                log_error("Protocol disconnected, starting reconnect.".into())
246            }
247            ProtocolConnectorError::ProtocolError(e) => {
248                log_error(format!("Protocol returned an error: {}", e))
249            }
250        })
251    }
252
253    /// serve creates both a ProtocolSender and a future that can be used to send
254    /// messages to the underlying protocol.
255    pub fn serve<ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>)>(
256        self,
257        h: ErrHandler,
258    ) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
259        let (sender, receiver) = mpsc::channel(self.buffer_size);
260        let sender = ProtocolSender::new(sender);
261        (sender, self.send_events(receiver, h))
262    }
263
264    async fn send_events<
265        ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>),
266    >(
267        mut self,
268        mut receiver: mpsc::Receiver<<CP as ConnectedProtocol>::Message>,
269        mut h: ErrHandler,
270    ) {
271        let mut backoff = ExponentialBackoff::new(zx::MonotonicDuration::from_millis(100), 2.0);
272        loop {
273            let protocol = match self.protocol.get_protocol().await {
274                Ok(protocol) => protocol,
275                Err(e) => {
276                    if !self.protocol.should_retry_on_connect_error(&e) {
277                        error!("Stopping retries as requested: {e}");
278                        return;
279                    }
280
281                    h(ProtocolConnectorError::ConnectFailed(e));
282                    backoff.next_timer().await;
283                    continue;
284                }
285            };
286
287            'receiving: loop {
288                match receiver.next().await {
289                    Some(message) => {
290                        let resp = self.protocol.send_message(&protocol, message).await;
291                        match resp {
292                            Ok(_) => {
293                                backoff.reset();
294                                continue;
295                            }
296                            Err(e) => {
297                                if fidl::endpoints::Proxy::is_closed(&protocol) {
298                                    h(ProtocolConnectorError::ConnectionLost);
299                                    break 'receiving;
300                                } else {
301                                    h(ProtocolConnectorError::ProtocolError(e));
302                                }
303                            }
304                        }
305                    }
306                    None => return,
307                }
308            }
309
310            backoff.next_timer().await;
311        }
312    }
313}
314
/// Wraps `log_fn` in a closure that forwards only the first `n` messages it is given.
///
/// Every message after the `n`-th is dropped without calling `log_fn`.
fn log_first_n_factory(n: u64, mut log_fn: impl FnMut(String)) -> impl FnMut(String) {
    let mut remaining = n;
    move |message| {
        if remaining == 0 {
            return;
        }
        remaining -= 1;
        log_fn(message);
    }
}
324
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::{Context, format_err};
    use fidl_test_protocol_connector::{
        ProtocolFactoryProxy, ProtocolFactoryRequest, ProtocolFactoryRequestStream, ProtocolProxy,
        ProtocolRequest, ProtocolRequestStream,
    };
    use fuchsia_async as fasync;
    use fuchsia_component::server as fserver;
    use fuchsia_component_test::{
        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
    };
    use futures::channel::mpsc::Sender;
    use futures::{FutureExt, TryStreamExt};
    use std::sync::atomic::AtomicU8;

    /// `ConnectedProtocol` used by the tests.
    ///
    /// Field 0 is the realm hosting the mock `ProtocolFactory`; field 1 is signalled after each
    /// message has been delivered successfully so the test can wait for delivery.
    struct ProtocolConnectedProtocol(RealmInstance, Sender<()>);
    impl ConnectedProtocol for ProtocolConnectedProtocol {
        type Protocol = ProtocolProxy;
        type ConnectError = anyhow::Error;
        type Message = ();
        type SendError = anyhow::Error;

        fn get_protocol<'a>(
            &'a mut self,
        ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
            async move {
                let (protocol_proxy, server_end) = fidl::endpoints::create_proxy();
                let protocol_factory: ProtocolFactoryProxy = self
                    .0
                    .root
                    .connect_to_protocol_at_exposed_dir()
                    .context("Connecting to test.protocol.ProtocolFactory failed")?;

                protocol_factory
                    .create_protocol(server_end)
                    .await?
                    .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;

                Ok(protocol_proxy)
            }
            .boxed()
        }

        fn send_message<'a>(
            &'a mut self,
            protocol: &'a Self::Protocol,
            _msg: (),
        ) -> BoxFuture<'a, Result<(), Self::SendError>> {
            async move {
                protocol
                    .do_action()
                    .await?
                    .map_err(|e| format_err!("Failed to do action: {:?}", e))?;
                // Tell the test that the call reached the mock and succeeded.
                self.1.try_send(())?;
                Ok(())
            }
            .boxed()
        }
    }

    /// Serves one `Protocol` connection, counting every `DoAction` call in `calls_made`.
    ///
    /// When `close_after` is set, its counter is decremented after each request and the stream
    /// stops being served (by returning an error) once the counter was at 1.
    async fn protocol_mock(
        stream: ProtocolRequestStream,
        calls_made: Arc<AtomicU8>,
        close_after: Option<Arc<AtomicU8>>,
    ) -> Result<(), anyhow::Error> {
        stream
            .map(|result| result.context("failed request"))
            .try_for_each(|request| async {
                let calls_made = calls_made.clone();
                let close_after = close_after.clone();
                match request {
                    ProtocolRequest::DoAction { responder } => {
                        calls_made.fetch_add(1, Ordering::SeqCst);
                        responder.send(Ok(()))?;
                    }
                }

                if let Some(ca) = &close_after {
                    if ca.fetch_sub(1, Ordering::SeqCst) == 1 {
                        return Err(format_err!("close_after triggered"));
                    }
                }
                Ok(())
            })
            .await
    }

    /// Local component that serves `ProtocolFactory` and, for every `CreateProtocol` request,
    /// serves the provided `Protocol` server end with `protocol_mock`.
    ///
    /// `close_after` gives each created `Protocol` connection its own fresh countdown.
    async fn protocol_factory_mock(
        handles: LocalComponentHandles,
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<(), anyhow::Error> {
        let mut fs = fserver::ServiceFs::new();
        // Holds the per-connection tasks so they are not dropped as soon as they are created.
        let mut tasks = vec![];

        fs.dir("svc").add_fidl_service(move |mut stream: ProtocolFactoryRequestStream| {
            let calls_made = calls_made.clone();
            tasks.push(fasync::Task::local(async move {
                while let Some(ProtocolFactoryRequest::CreateProtocol { protocol, responder }) =
                    stream.try_next().await.expect("ProtocolFactoryRequestStream yielded an Err(_)")
                {
                    let close_after = close_after.map(|ca| Arc::new(AtomicU8::new(ca)));
                    responder.send(Ok(())).expect("Replying to CreateProtocol caller failed");
                    // An `Err` here is how `protocol_mock` signals a deliberate close.
                    let _ = protocol_mock(protocol.into_stream(), calls_made.clone(), close_after)
                        .await;
                }
            }));
        });

        fs.serve_connection(handles.outgoing_dir)?;
        fs.collect::<()>().await;

        Ok(())
    }

    /// Builds a realm whose only child is the mock `ProtocolFactory`, exposed to the parent.
    async fn setup_realm(
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<RealmInstance, anyhow::Error> {
        let builder = RealmBuilder::new().await?;

        let protocol_factory_server = builder
            .add_local_child(
                "protocol_factory",
                move |handles: LocalComponentHandles| {
                    Box::pin(protocol_factory_mock(handles, calls_made.clone(), close_after))
                },
                ChildOptions::new(),
            )
            .await?;

        builder
            .add_route(
                Route::new()
                    .capability(Capability::protocol_by_name(
                        "test.protocol.connector.ProtocolFactory",
                    ))
                    .from(&protocol_factory_server)
                    .to(Ref::parent()),
            )
            .await?;

        Ok(builder.build().await?)
    }

    /// Ten messages sent over a healthy connection all reach the mock and report no errors.
    #[fuchsia::test(logging_tags = ["test_protocol_connector"])]
    async fn test_protocol_connector() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));
        let realm = setup_realm(calls_made.clone(), None).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let error_count = Arc::new(AtomicU8::new(0));
        let svc = ProtocolConnector::new(connector);
        let (mut sender, fut) = svc.serve({
            let count = error_count.clone();
            move |e| {
                error!("Encountered unexpected error: {:?}", e);
                count.fetch_add(1, Ordering::SeqCst);
            }
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;
        }

        assert_eq!(calls_made.load(Ordering::SeqCst), 10);
        assert_eq!(error_count.load(Ordering::SeqCst), 0);

        Ok(())
    }

    /// When the mock closes the connection after every call, the connector reports
    /// `ConnectionLost` and reconnects so that later messages are delivered.
    #[fuchsia::test(logging_tags = ["test_protocol_reconnect"])]
    async fn test_protocol_reconnect() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));

        // Simulate the protocol closing after each successful call.
        let realm = setup_realm(calls_made.clone(), Some(1)).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let svc = ProtocolConnector::new(connector);
        let (mut err_send, mut err_rcv) = mpsc::channel(1);
        let (mut sender, fut) = svc.serve(move |e| {
            err_send.try_send(e).expect("Could not log error");
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            // This first send will successfully call the underlying protocol.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;

            // The second send will not, because the protocol has shut down.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            match err_rcv.next().await.expect("Expected err") {
                ProtocolConnectorError::ConnectionLost => {}
                other => panic!("saw unexpected error type: {other:?}"),
            }
        }

        // Only the first send of each iteration reaches the mock.
        assert_eq!(calls_made.load(Ordering::SeqCst), 10);

        Ok(())
    }
}