1use fuchsia_async::{self as fasync, DurationExt};
80
81use futures::channel::mpsc;
82use futures::future::BoxFuture;
83use futures::{Future, StreamExt};
84use log::error;
85use std::sync::Arc;
86use std::sync::atomic::{AtomicBool, Ordering};
87
88pub trait ConnectedProtocol {
90 type Protocol: fidl::endpoints::Proxy;
92
93 type ConnectError: std::fmt::Display + 'static;
95
96 type Message;
98
99 type SendError: std::fmt::Display;
101
102 fn get_protocol<'a>(&'a mut self) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>>;
107
108 fn send_message<'a>(
112 &'a mut self,
113 protocol: &'a Self::Protocol,
114 msg: Self::Message,
115 ) -> BoxFuture<'a, Result<(), Self::SendError>>;
116
117 fn should_retry_on_connect_error(&self, _error: &Self::ConnectError) -> bool {
119 true
120 }
121}
122
/// Handle for queueing messages to a running [`ProtocolConnector`].
///
/// Cloning the handle clones the underlying `mpsc::Sender`. All clones share one blocked
/// flag through the `Arc`, so the status returned by `send` reflects sends from every clone.
#[derive(Clone, Debug)]
pub struct ProtocolSender<Msg> {
    // Queue feeding the connector's background future.
    sender: mpsc::Sender<Msg>,
    // True while the most recent `try_send` (from any clone) failed.
    is_blocked: Arc<AtomicBool>,
}
130
/// Outcome of [`ProtocolSender::send`], describing transitions of the sender's blocked state.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ProtocolSenderStatus {
    /// The message was queued and the sender was not previously blocked.
    Healthy,

    /// The message could not be queued and this is the first such failure since the last
    /// successful send; the message was dropped.
    BackoffStarts,

    /// The message could not be queued and the sender was already blocked; the message was
    /// dropped.
    InBackoff,

    /// The message was queued and the sender had previously been blocked.
    BackoffEnds,
}
147
148impl<Msg> ProtocolSender<Msg> {
149 pub fn new(sender: mpsc::Sender<Msg>) -> Self {
151 Self { sender, is_blocked: Arc::new(AtomicBool::new(false)) }
152 }
153
154 pub fn send(&mut self, message: Msg) -> ProtocolSenderStatus {
159 if self.sender.try_send(message).is_err() {
160 let was_blocked =
161 self.is_blocked.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
162 if let Ok(false) = was_blocked {
163 ProtocolSenderStatus::BackoffStarts
164 } else {
165 ProtocolSenderStatus::InBackoff
166 }
167 } else {
168 let was_blocked =
169 self.is_blocked.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
170 if let Ok(true) = was_blocked {
171 ProtocolSenderStatus::BackoffEnds
172 } else {
173 ProtocolSenderStatus::Healthy
174 }
175 }
176 }
177}
178
179struct ExponentialBackoff {
180 initial: zx::MonotonicDuration,
181 current: zx::MonotonicDuration,
182 factor: f64,
183}
184
185impl ExponentialBackoff {
186 fn new(initial: zx::MonotonicDuration, factor: f64) -> Self {
187 Self { initial, current: initial, factor }
188 }
189
190 fn next_timer(&mut self) -> fasync::Timer {
191 let timer = fasync::Timer::new(self.current.after_now());
192 self.current = zx::MonotonicDuration::from_nanos(
193 (self.current.into_nanos() as f64 * self.factor) as i64,
194 );
195 timer
196 }
197
198 fn reset(&mut self) {
199 self.current = self.initial;
200 }
201}
202
/// Errors passed to the handler given to [`ProtocolConnector::serve`].
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolConnectorError<ConnectError, ProtocolError> {
    /// [`ConnectedProtocol::get_protocol`] failed; the connector waits for its backoff delay
    /// and then tries to connect again.
    ConnectFailed(ConnectError),

    /// A send failed and the proxy reports its channel as closed; the connector waits for its
    /// backoff delay and reconnects. The message whose send failed is not resent.
    ConnectionLost,

    /// A send failed while the proxy was still open; the connection is kept and the message
    /// whose send failed is not resent.
    ProtocolError(ProtocolError),
}
/// Keeps a connection to a FIDL protocol open and forwards queued messages over it,
/// reconnecting with exponential backoff when the connection fails.
pub struct ProtocolConnector<CP: ConnectedProtocol> {
    /// Buffer size passed to `mpsc::channel` when `serve` creates the message queue.
    pub buffer_size: usize,
    // Strategy for connecting and delivering messages.
    protocol: CP,
}
222
223impl<CP: ConnectedProtocol> ProtocolConnector<CP> {
224 pub fn new(protocol: CP) -> Self {
226 Self::new_with_buffer_size(protocol, 10)
227 }
228
229 pub fn new_with_buffer_size(protocol: CP, buffer_size: usize) -> Self {
231 Self { buffer_size, protocol }
232 }
233
234 pub fn serve_and_log_errors(self) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
238 let protocol = <<<CP as ConnectedProtocol>::Protocol as fidl::endpoints::Proxy>::Protocol as fidl::endpoints::ProtocolMarker>::DEBUG_NAME;
239 let mut log_error = log_first_n_factory(30, move |e| error!(protocol:%; "{e}"));
240 self.serve(move |e| match e {
241 ProtocolConnectorError::ConnectFailed(e) => {
242 log_error(format!("Error obtaining a connection to the protocol: {}", e))
243 }
244 ProtocolConnectorError::ConnectionLost => {
245 log_error("Protocol disconnected, starting reconnect.".into())
246 }
247 ProtocolConnectorError::ProtocolError(e) => {
248 log_error(format!("Protocol returned an error: {}", e))
249 }
250 })
251 }
252
253 pub fn serve<ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>)>(
256 self,
257 h: ErrHandler,
258 ) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
259 let (sender, receiver) = mpsc::channel(self.buffer_size);
260 let sender = ProtocolSender::new(sender);
261 (sender, self.send_events(receiver, h))
262 }
263
264 async fn send_events<
265 ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>),
266 >(
267 mut self,
268 mut receiver: mpsc::Receiver<<CP as ConnectedProtocol>::Message>,
269 mut h: ErrHandler,
270 ) {
271 let mut backoff = ExponentialBackoff::new(zx::MonotonicDuration::from_millis(100), 2.0);
272 loop {
273 let protocol = match self.protocol.get_protocol().await {
274 Ok(protocol) => protocol,
275 Err(e) => {
276 if !self.protocol.should_retry_on_connect_error(&e) {
277 error!("Stopping retries as requested: {e}");
278 return;
279 }
280
281 h(ProtocolConnectorError::ConnectFailed(e));
282 backoff.next_timer().await;
283 continue;
284 }
285 };
286
287 'receiving: loop {
288 match receiver.next().await {
289 Some(message) => {
290 let resp = self.protocol.send_message(&protocol, message).await;
291 match resp {
292 Ok(_) => {
293 backoff.reset();
294 continue;
295 }
296 Err(e) => {
297 if fidl::endpoints::Proxy::is_closed(&protocol) {
298 h(ProtocolConnectorError::ConnectionLost);
299 break 'receiving;
300 } else {
301 h(ProtocolConnectorError::ProtocolError(e));
302 }
303 }
304 }
305 }
306 None => return,
307 }
308 }
309
310 backoff.next_timer().await;
311 }
312 }
313}
314
/// Wraps `log_fn` so that only the first `n` messages are forwarded to it.
///
/// The returned closure keeps a countdown of the messages it may still forward; once it reaches
/// zero, every later message is silently dropped.
fn log_first_n_factory(n: u64, mut log_fn: impl FnMut(String)) -> impl FnMut(String) {
    let mut remaining = n;
    move |message| {
        // `checked_sub` yields `None` once the budget is exhausted.
        if let Some(left) = remaining.checked_sub(1) {
            remaining = left;
            log_fn(message);
        }
    }
}
324
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::{Context, format_err};
    use fidl_test_protocol_connector::{
        ProtocolFactoryProxy, ProtocolFactoryRequest, ProtocolFactoryRequestStream, ProtocolProxy,
        ProtocolRequest, ProtocolRequestStream,
    };
    use fuchsia_async as fasync;
    use fuchsia_component::server as fserver;
    use fuchsia_component_test::{
        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
    };
    use futures::channel::mpsc::Sender;
    use futures::{FutureExt, TryStreamExt};
    use std::sync::atomic::AtomicU8;

    /// Test `ConnectedProtocol` backed by a realm exposing the mock `ProtocolFactory`.
    ///
    /// The `Sender<()>` is signalled after every successful `DoAction`, so a test can wait for
    /// each message to be fully delivered.
    struct ProtocolConnectedProtocol(RealmInstance, Sender<()>);
    impl ConnectedProtocol for ProtocolConnectedProtocol {
        type Protocol = ProtocolProxy;
        type ConnectError = anyhow::Error;
        type Message = ();
        type SendError = anyhow::Error;

        fn get_protocol<'a>(
            &'a mut self,
        ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
            async move {
                let (protocol_proxy, server_end) = fidl::endpoints::create_proxy();
                let protocol_factory: ProtocolFactoryProxy = self
                    .0
                    .root
                    .connect_to_protocol_at_exposed_dir()
                    .context("Connecting to test.protocol.connector.ProtocolFactory failed")?;

                protocol_factory
                    .create_protocol(server_end)
                    .await?
                    .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;

                Ok(protocol_proxy)
            }
            .boxed()
        }

        fn send_message<'a>(
            &'a mut self,
            protocol: &'a Self::Protocol,
            _msg: (),
        ) -> BoxFuture<'a, Result<(), Self::SendError>> {
            async move {
                protocol
                    .do_action()
                    .await?
                    .map_err(|e| format_err!("Failed to do action: {:?}", e))?;
                self.1.try_send(())?;
                Ok(())
            }
            .boxed()
        }
    }

    /// Serves `Protocol` requests, counting each `DoAction` in `calls_made`.
    ///
    /// When `close_after` is set, the stream is ended (closing the channel) once that many
    /// calls have been answered.
    async fn protocol_mock(
        stream: ProtocolRequestStream,
        calls_made: Arc<AtomicU8>,
        close_after: Option<Arc<AtomicU8>>,
    ) -> Result<(), anyhow::Error> {
        stream
            .map(|result| result.context("failed request"))
            .try_for_each(|request| async {
                let calls_made = calls_made.clone();
                let close_after = close_after.clone();
                match request {
                    ProtocolRequest::DoAction { responder } => {
                        calls_made.fetch_add(1, Ordering::SeqCst);
                        responder.send(Ok(()))?;
                    }
                }

                if let Some(ca) = &close_after {
                    if ca.fetch_sub(1, Ordering::SeqCst) == 1 {
                        return Err(format_err!("close_after triggered"));
                    }
                }
                Ok(())
            })
            .await
    }

    /// Serves `ProtocolFactory` from the local component's outgoing directory.
    ///
    /// Every `CreateProtocol` request gets its own `close_after` counter, so each connection
    /// handed out closes after the same number of calls.
    async fn protocol_factory_mock(
        handles: LocalComponentHandles,
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<(), anyhow::Error> {
        let mut fs = fserver::ServiceFs::new();
        let mut tasks = vec![];

        fs.dir("svc").add_fidl_service(move |mut stream: ProtocolFactoryRequestStream| {
            let calls_made = calls_made.clone();
            // Tasks are kept in `tasks` so they are not cancelled by being dropped.
            tasks.push(fasync::Task::local(async move {
                while let Some(ProtocolFactoryRequest::CreateProtocol { protocol, responder }) =
                    stream.try_next().await.expect("ProtocolFactoryRequestStream yielded an Err(_)")
                {
                    let close_after = close_after.map(|ca| Arc::new(AtomicU8::new(ca)));
                    responder.send(Ok(())).expect("Replying to CreateProtocol caller failed");
                    let _ = protocol_mock(protocol.into_stream(), calls_made.clone(), close_after)
                        .await;
                }
            }));
        });

        fs.serve_connection(handles.outgoing_dir)?;
        fs.collect::<()>().await;

        Ok(())
    }

    /// Builds a realm whose only child is the mock factory, exposed to the test.
    async fn setup_realm(
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<RealmInstance, anyhow::Error> {
        let builder = RealmBuilder::new().await?;

        let protocol_factory_server = builder
            .add_local_child(
                "protocol_factory",
                move |handles: LocalComponentHandles| {
                    Box::pin(protocol_factory_mock(handles, calls_made.clone(), close_after))
                },
                ChildOptions::new(),
            )
            .await?;

        builder
            .add_route(
                Route::new()
                    .capability(Capability::protocol_by_name(
                        "test.protocol.connector.ProtocolFactory",
                    ))
                    .from(&protocol_factory_server)
                    .to(Ref::parent()),
            )
            .await?;

        Ok(builder.build().await?)
    }

    #[fuchsia::test(logging_tags = ["test_protocol_connector"])]
    async fn test_protocol_connector() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));
        let realm = setup_realm(calls_made.clone(), None).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let error_count = Arc::new(AtomicU8::new(0));
        let svc = ProtocolConnector::new(connector);
        let (mut sender, fut) = svc.serve({
            let count = error_count.clone();
            move |e| {
                error!("Encountered unexpected error: {:?}", e);
                count.fetch_add(1, Ordering::SeqCst);
            }
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;
        }

        assert_eq!(calls_made.load(Ordering::SeqCst), 10);
        assert_eq!(error_count.load(Ordering::SeqCst), 0);

        Ok(())
    }

    #[fuchsia::test(logging_tags = ["test_protocol_reconnect"])]
    async fn test_protocol_reconnect() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));

        // Every connection the mock hands out is closed after answering one call.
        let realm = setup_realm(calls_made.clone(), Some(1)).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let svc = ProtocolConnector::new(connector);
        let (mut err_send, mut err_rcv) = mpsc::channel(1);
        let (mut sender, fut) = svc.serve(move |e| {
            err_send.try_send(e).expect("Could not log error");
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            // Delivered over a fresh connection, which the mock then closes.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;

            // Hits the closed connection: reported as ConnectionLost and dropped, after which
            // the connector reconnects for the next iteration.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            let err = err_rcv.next().await.expect("Expected err");
            assert!(
                matches!(err, ProtocolConnectorError::ConnectionLost),
                "saw unexpected error type: {:?}",
                err
            );
        }

        // Only the first message of each pair reached the mock.
        assert_eq!(calls_made.load(Ordering::SeqCst), 10);

        Ok(())
    }
}