1use fuchsia_async::{self as fasync, DurationExt};
80
81use futures::channel::mpsc;
82use futures::future::BoxFuture;
83use futures::{Future, StreamExt};
84use log::error;
85use std::sync::atomic::{AtomicBool, Ordering};
86use std::sync::Arc;
87
/// A protocol that [`ProtocolConnector`] can connect to (and reconnect to) and forward queued
/// messages over.
pub trait ConnectedProtocol {
    /// The FIDL proxy type produced by [`ConnectedProtocol::get_protocol`]; its
    /// `Proxy::is_closed` state is what the connector uses to detect a lost connection.
    type Protocol: fidl::endpoints::Proxy;

    /// Error returned when a connection cannot be obtained. It is reported to the error handler
    /// as [`ProtocolConnectorError::ConnectFailed`].
    type ConnectError: std::fmt::Display;

    /// Message type queued through [`ProtocolSender`] and handed to
    /// [`ConnectedProtocol::send_message`].
    type Message;

    /// Error returned by [`ConnectedProtocol::send_message`].
    type SendError: std::fmt::Display;

    /// Obtains a fresh proxy to the protocol.
    ///
    /// The connector calls this when its serving future starts, and again after every failed
    /// connection attempt or lost connection (each time after waiting out its backoff).
    fn get_protocol<'a>(&'a mut self) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>>;

    /// Sends one `msg` over `protocol`.
    ///
    /// On failure the connector checks whether `protocol` is closed. If it is, the connection
    /// is treated as lost and re-established. Otherwise the error is reported as
    /// [`ProtocolConnectorError::ProtocolError`] and the same proxy keeps being used. In both
    /// cases `msg` is not retried.
    fn send_message<'a>(
        &'a mut self,
        protocol: &'a Self::Protocol,
        msg: Self::Message,
    ) -> BoxFuture<'a, Result<(), Self::SendError>>;
}
117
/// Non-blocking handle used to queue messages for a [`ProtocolConnector`].
///
/// Created by [`ProtocolConnector::serve`]. Clones share the same backoff flag (it lives in an
/// `Arc`), so every clone observes the same backoff state.
#[derive(Clone, Debug)]
pub struct ProtocolSender<Msg> {
    /// Sending half of the connector's message queue.
    sender: mpsc::Sender<Msg>,
    /// Set when a `send` fails and cleared by the next successful `send`.
    is_blocked: Arc<AtomicBool>,
}
125
/// Outcome of [`ProtocolSender::send`]: whether the message was queued, and whether this call
/// moved the sender into or out of backoff.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ProtocolSenderStatus {
    /// The message was queued and the sender was not in backoff.
    Healthy,

    /// The message could not be queued and was dropped. The sender was not in backoff before
    /// this call and now is.
    BackoffStarts,

    /// The message could not be queued and was dropped. The sender was already in backoff.
    InBackoff,

    /// The message was queued after earlier failed sends. The sender leaves backoff.
    BackoffEnds,
}
142
143impl<Msg> ProtocolSender<Msg> {
144 pub fn new(sender: mpsc::Sender<Msg>) -> Self {
146 Self { sender, is_blocked: Arc::new(AtomicBool::new(false)) }
147 }
148
149 pub fn send(&mut self, message: Msg) -> ProtocolSenderStatus {
154 if self.sender.try_send(message).is_err() {
155 let was_blocked =
156 self.is_blocked.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
157 if let Ok(false) = was_blocked {
158 ProtocolSenderStatus::BackoffStarts
159 } else {
160 ProtocolSenderStatus::InBackoff
161 }
162 } else {
163 let was_blocked =
164 self.is_blocked.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
165 if let Ok(true) = was_blocked {
166 ProtocolSenderStatus::BackoffEnds
167 } else {
168 ProtocolSenderStatus::Healthy
169 }
170 }
171 }
172}
173
174struct ExponentialBackoff {
175 initial: zx::MonotonicDuration,
176 current: zx::MonotonicDuration,
177 factor: f64,
178}
179
180impl ExponentialBackoff {
181 fn new(initial: zx::MonotonicDuration, factor: f64) -> Self {
182 Self { initial, current: initial, factor }
183 }
184
185 fn next_timer(&mut self) -> fasync::Timer {
186 let timer = fasync::Timer::new(self.current.after_now());
187 self.current = zx::MonotonicDuration::from_nanos(
188 (self.current.into_nanos() as f64 * self.factor) as i64,
189 );
190 timer
191 }
192
193 fn reset(&mut self) {
194 self.current = self.initial;
195 }
196}
197
/// Errors reported to the handler passed to [`ProtocolConnector::serve`].
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolConnectorError<ConnectError, ProtocolError> {
    /// `ConnectedProtocol::get_protocol` failed. The connector waits out its backoff and
    /// tries again.
    ConnectFailed(ConnectError),

    /// A send failed and the proxy reports itself closed. The connector waits out its backoff
    /// and reconnects. The message that failed is not resent.
    ConnectionLost,

    /// A send failed but the proxy is still open. The connector keeps using the same proxy.
    /// The message that failed is not resent.
    ProtocolError(ProtocolError),
}
/// Keeps a connection to a [`ConnectedProtocol`] alive, reconnecting with exponential backoff,
/// and forwards messages queued through a [`ProtocolSender`] to it.
pub struct ProtocolConnector<CP: ConnectedProtocol> {
    /// Buffer size passed to `mpsc::channel` when [`ProtocolConnector::serve`] creates the
    /// message queue.
    pub buffer_size: usize,
    /// Implementation used to obtain proxies and send messages.
    protocol: CP,
}
217
218impl<CP: ConnectedProtocol> ProtocolConnector<CP> {
219 pub fn new(protocol: CP) -> Self {
221 Self::new_with_buffer_size(protocol, 10)
222 }
223
224 pub fn new_with_buffer_size(protocol: CP, buffer_size: usize) -> Self {
226 Self { buffer_size, protocol }
227 }
228
229 pub fn serve_and_log_errors(self) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
233 let protocol = <<<CP as ConnectedProtocol>::Protocol as fidl::endpoints::Proxy>::Protocol as fidl::endpoints::ProtocolMarker>::DEBUG_NAME;
234 let mut log_error = log_first_n_factory(30, move |e| error!(protocol:%; "{e}"));
235 self.serve(move |e| match e {
236 ProtocolConnectorError::ConnectFailed(e) => {
237 log_error(format!("Error obtaining a connection to the protocol: {}", e))
238 }
239 ProtocolConnectorError::ConnectionLost => {
240 log_error("Protocol disconnected, starting reconnect.".into())
241 }
242 ProtocolConnectorError::ProtocolError(e) => {
243 log_error(format!("Protocol returned an error: {}", e))
244 }
245 })
246 }
247
248 pub fn serve<ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>)>(
251 self,
252 h: ErrHandler,
253 ) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
254 let (sender, receiver) = mpsc::channel(self.buffer_size);
255 let sender = ProtocolSender::new(sender);
256 (sender, self.send_events(receiver, h))
257 }
258
259 async fn send_events<
260 ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>),
261 >(
262 mut self,
263 mut receiver: mpsc::Receiver<<CP as ConnectedProtocol>::Message>,
264 mut h: ErrHandler,
265 ) {
266 let mut backoff = ExponentialBackoff::new(zx::MonotonicDuration::from_millis(100), 2.0);
267 loop {
268 let protocol = match self.protocol.get_protocol().await {
269 Ok(protocol) => protocol,
270 Err(e) => {
271 h(ProtocolConnectorError::ConnectFailed(e));
272 backoff.next_timer().await;
273 continue;
274 }
275 };
276
277 'receiving: loop {
278 match receiver.next().await {
279 Some(message) => {
280 let resp = self.protocol.send_message(&protocol, message).await;
281 match resp {
282 Ok(_) => {
283 backoff.reset();
284 continue;
285 }
286 Err(e) => {
287 if fidl::endpoints::Proxy::is_closed(&protocol) {
288 h(ProtocolConnectorError::ConnectionLost);
289 break 'receiving;
290 } else {
291 h(ProtocolConnectorError::ProtocolError(e));
292 }
293 }
294 }
295 }
296 None => return,
297 }
298 }
299
300 backoff.next_timer().await;
301 }
302 }
303}
304
/// Wraps `log_fn` so that only the first `n` messages passed to the returned closure are
/// forwarded. Every later message is silently discarded.
fn log_first_n_factory(n: u64, mut log_fn: impl FnMut(String)) -> impl FnMut(String) {
    let mut remaining = n;
    move |message| {
        // `checked_sub` yields `None` once the budget is exhausted, which suppresses logging.
        if let Some(left) = remaining.checked_sub(1) {
            remaining = left;
            log_fn(message);
        }
    }
}
314
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::{format_err, Context};
    use fidl_test_protocol_connector::{
        ProtocolFactoryMarker, ProtocolFactoryRequest, ProtocolFactoryRequestStream, ProtocolProxy,
        ProtocolRequest, ProtocolRequestStream,
    };
    use fuchsia_async as fasync;
    use fuchsia_component::server as fserver;
    use fuchsia_component_test::{
        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
    };
    use futures::channel::mpsc::Sender;
    use futures::{FutureExt, TryStreamExt};
    use std::sync::atomic::AtomicU8;

    /// Test `ConnectedProtocol` that connects through the realm's exposed `ProtocolFactory`.
    /// Field `.1` receives one `()` after each successful `do_action`, so tests can wait for
    /// delivery.
    struct ProtocolConnectedProtocol(RealmInstance, Sender<()>);
    impl ConnectedProtocol for ProtocolConnectedProtocol {
        type Protocol = ProtocolProxy;
        type ConnectError = anyhow::Error;
        type Message = ();
        type SendError = anyhow::Error;

        fn get_protocol<'a>(
            &'a mut self,
        ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
            async move {
                let (protocol_proxy, server_end) = fidl::endpoints::create_proxy();
                let protocol_factory = self
                    .0
                    .root
                    .connect_to_protocol_at_exposed_dir::<ProtocolFactoryMarker>()
                    .context("Connecting to test.protocol.connector.ProtocolFactory failed")?;

                protocol_factory
                    .create_protocol(server_end)
                    .await?
                    .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;

                Ok(protocol_proxy)
            }
            .boxed()
        }

        fn send_message<'a>(
            &'a mut self,
            protocol: &'a Self::Protocol,
            _msg: (),
        ) -> BoxFuture<'a, Result<(), Self::SendError>> {
            async move {
                protocol
                    .do_action()
                    .await?
                    .map_err(|e| format_err!("Failed to do action: {:?}", e))?;
                self.1.try_send(())?;
                Ok(())
            }
            .boxed()
        }
    }

    /// Serves one `Protocol` connection, counting `DoAction` calls in `calls_made`.
    ///
    /// When `close_after` is set, it is decremented after every call. The stream is ended with
    /// an error once it reaches zero, which closes the channel and simulates the server going
    /// away.
    async fn protocol_mock(
        stream: ProtocolRequestStream,
        calls_made: Arc<AtomicU8>,
        close_after: Option<Arc<AtomicU8>>,
    ) -> Result<(), anyhow::Error> {
        stream
            .map(|result| result.context("failed request"))
            .try_for_each(|request| async {
                match request {
                    ProtocolRequest::DoAction { responder } => {
                        calls_made.fetch_add(1, Ordering::SeqCst);
                        responder.send(Ok(()))?;
                    }
                }

                if let Some(ca) = &close_after {
                    if ca.fetch_sub(1, Ordering::SeqCst) == 1 {
                        return Err(format_err!("close_after triggered"));
                    }
                }
                Ok(())
            })
            .await
    }

    /// Local child component that exposes `ProtocolFactory`.
    ///
    /// Each `CreateProtocol` request gets its own `protocol_mock` with a fresh `close_after`
    /// budget.
    async fn protocol_factory_mock(
        handles: LocalComponentHandles,
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<(), anyhow::Error> {
        let mut fs = fserver::ServiceFs::new();
        // Keeps the per-connection tasks alive; dropping a `fasync::Task` cancels it.
        let mut tasks = vec![];

        fs.dir("svc").add_fidl_service(move |mut stream: ProtocolFactoryRequestStream| {
            let calls_made = calls_made.clone();
            tasks.push(fasync::Task::local(async move {
                while let Some(ProtocolFactoryRequest::CreateProtocol { protocol, responder }) =
                    stream.try_next().await.expect("ProtocolFactoryRequestStream yielded an Err(_)")
                {
                    let close_after = close_after.map(|ca| Arc::new(AtomicU8::new(ca)));
                    responder.send(Ok(())).expect("Replying to CreateProtocol caller failed");
                    // The mock deliberately errors out when `close_after` triggers, so its
                    // result is ignored.
                    let _ = protocol_mock(protocol.into_stream(), calls_made.clone(), close_after)
                        .await;
                }
            }));
        });

        fs.serve_connection(handles.outgoing_dir)?;
        fs.collect::<()>().await;

        Ok(())
    }

    /// Builds a realm with the factory mock and routes `ProtocolFactory` to the test.
    async fn setup_realm(
        calls_made: Arc<AtomicU8>,
        close_after: Option<u8>,
    ) -> Result<RealmInstance, anyhow::Error> {
        let builder = RealmBuilder::new().await?;

        let protocol_factory_server = builder
            .add_local_child(
                "protocol_factory",
                move |handles: LocalComponentHandles| {
                    Box::pin(protocol_factory_mock(handles, calls_made.clone(), close_after))
                },
                ChildOptions::new(),
            )
            .await?;

        builder
            .add_route(
                Route::new()
                    .capability(Capability::protocol_by_name(
                        "test.protocol.connector.ProtocolFactory",
                    ))
                    .from(&protocol_factory_server)
                    .to(Ref::parent()),
            )
            .await?;

        Ok(builder.build().await?)
    }

    #[fuchsia::test(logging_tags = ["test_protocol_connector"])]
    async fn test_protocol_connector() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));
        let realm = setup_realm(calls_made.clone(), None).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let error_count = Arc::new(AtomicU8::new(0));
        let svc = ProtocolConnector::new(connector);
        let (mut sender, fut) = svc.serve({
            let count = error_count.clone();
            move |e| {
                error!("Encountered unexpected error: {:?}", e);
                count.fetch_add(1, Ordering::SeqCst);
            }
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;
        }

        assert_eq!(calls_made.load(Ordering::SeqCst), 10);
        assert_eq!(error_count.load(Ordering::SeqCst), 0);

        Ok(())
    }

    #[fuchsia::test(logging_tags = ["test_protocol_reconnect"])]
    async fn test_protocol_reconnect() -> Result<(), anyhow::Error> {
        let calls_made = Arc::new(AtomicU8::new(0));

        // Every connection serves exactly one `DoAction` and then closes.
        let realm = setup_realm(calls_made.clone(), Some(1)).await?;
        let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
        let connector = ProtocolConnectedProtocol(realm, log_received_sender);

        let svc = ProtocolConnector::new(connector);
        let (mut err_send, mut err_rcv) = mpsc::channel(1);
        let (mut sender, fut) = svc.serve(move |e| {
            err_send.try_send(e).expect("Could not log error");
        });

        let _server = fasync::Task::local(fut);

        for _ in 0..10 {
            // The first message on a fresh connection succeeds...
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            log_received_receiver.next().await;

            // ...and the second hits the closed channel, which must surface as ConnectionLost.
            assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
            let err = err_rcv.next().await.expect("Expected err");
            assert!(
                matches!(err, ProtocolConnectorError::ConnectionLost),
                "saw unexpected error type: {:?}",
                err
            );
        }

        assert_eq!(calls_made.load(Ordering::SeqCst), 10);

        Ok(())
    }
}