1use crate::responding_channel as responding;
6use async_utils::hanging_get::error::HangingGetServerError;
7use async_utils::stream::{StreamItem, WithEpitaph};
8use core::hash::Hash;
9use futures::channel::mpsc;
10use futures::{select, SinkExt, StreamExt};
11use std::collections::HashMap;
12
/// Default buffer size for hanging-get broker channels.
///
/// Not referenced elsewhere in this file; presumably intended as the `channel_size` argument of
/// [`HangingGetBroker::new`] for callers with no better value.
pub const DEFAULT_CHANNEL_SIZE: usize = 128;
16
17pub struct HangingGetBroker<S, O: Unpin + 'static, F: Fn(&S, O) -> bool> {
82 inner: HangingGet<S, subscriber_key::Key, O, F>,
83 publisher: Publisher<S>,
84 updates: mpsc::Receiver<UpdateFn<S>>,
85 registrar: SubscriptionRegistrar<O>,
86 subscription_requests: responding::Receiver<(), Subscriber<O>>,
87 subscriber_key_generator: subscriber_key::Generator,
90 channel_size: usize,
91}
92
93impl<S, O, F> HangingGetBroker<S, O, F>
94where
95 S: Clone + Send,
96 O: Send + Unpin + 'static,
97 F: Fn(&S, O) -> bool,
98{
99 pub fn new(state: S, notify: F, channel_size: usize) -> Self {
104 let (sender, updates) = mpsc::channel(channel_size);
105 let publisher = Publisher { sender };
106 let (sender, subscription_requests) = responding::channel(channel_size);
107 let registrar = SubscriptionRegistrar { sender };
108 Self {
109 inner: HangingGet::new(state, notify),
110 publisher,
111 updates,
112 registrar,
113 subscription_requests,
114 subscriber_key_generator: subscriber_key::Generator::default(),
115 channel_size,
116 }
117 }
118
119 pub fn new_publisher(&self) -> Publisher<S> {
122 self.publisher.clone()
123 }
124
125 pub fn new_registrar(&self) -> SubscriptionRegistrar<O> {
128 self.registrar.clone()
129 }
130
131 pub async fn run(mut self) {
135 drop(self.publisher);
139 drop(self.registrar);
140
141 let mut subscriptions = futures::stream::SelectAll::new();
145
146 loop {
147 select! {
148 update = self.updates.next() => {
150 if let Some(update) = update {
151 self.inner.update(update)
152 }
153 }
154 subscriber = self.subscription_requests.next() => {
156 if let Some((_, responder)) = subscriber {
157 let (sender, receiver) = responding::channel(self.channel_size);
158 let key = self.subscriber_key_generator.next().unwrap();
159 if let Ok(()) = responder.respond(sender.into()) {
160 subscriptions.push(receiver.map(move |o| (key, o)).with_epitaph(key));
161 }
162 }
163 }
164 observer = subscriptions.next() => {
166 match observer {
167 Some(StreamItem::Item((key, (observer, responder)))) => {
168 let _ = responder.respond(self.inner.subscribe(key, observer));
169 },
170 Some(StreamItem::Epitaph(key)) => {
171 self.inner.unsubscribe(key);
172 },
173 None => (),
174 }
175 }
176 complete => break,
178 }
179 }
180 }
181}
182
183pub struct SubscriptionRegistrar<O> {
186 sender: responding::Sender<(), Subscriber<O>>,
187}
188
189impl<O> Clone for SubscriptionRegistrar<O> {
190 fn clone(&self) -> Self {
191 Self { sender: self.sender.clone() }
192 }
193}
194
195impl<O> SubscriptionRegistrar<O> {
196 pub async fn new_subscriber(&mut self) -> Result<Subscriber<O>, HangingGetServerError> {
198 Ok(self.sender.request(()).await?)
199 }
200}
201
202pub struct Subscriber<O> {
206 sender: responding::Sender<O, Result<(), HangingGetServerError>>,
207}
208
209impl<O> From<responding::Sender<O, Result<(), HangingGetServerError>>> for Subscriber<O> {
210 fn from(sender: responding::Sender<O, Result<(), HangingGetServerError>>) -> Self {
211 Self { sender }
212 }
213}
214
215impl<O> Subscriber<O> {
216 pub async fn register(&mut self, observation: O) -> Result<(), HangingGetServerError> {
221 self.sender.request(observation).await?
222 }
223}
224
225type UpdateFn<S> = Box<dyn FnOnce(&mut S) -> bool + Send + 'static>;
228
229pub struct Publisher<S> {
232 sender: mpsc::Sender<UpdateFn<S>>,
233}
234
235impl<S> Clone for Publisher<S> {
236 fn clone(&self) -> Self {
237 Publisher { sender: self.sender.clone() }
238 }
239}
240
241impl<S> Publisher<S>
242where
243 S: Send + 'static,
244{
245 pub async fn set(&mut self, state: S) -> Result<(), HangingGetServerError> {
248 Ok(self
249 .sender
250 .send(Box::new(move |s| {
251 *s = state;
252 true
253 }))
254 .await?)
255 }
256
257 pub async fn update<F>(&mut self, update: F) -> Result<(), HangingGetServerError>
260 where
261 F: FnOnce(&mut S) -> bool + Send + 'static,
262 {
263 Ok(self.sender.send(Box::new(update)).await?)
264 }
265}
266
267pub struct HangingGet<S, K, O, F: Fn(&S, O) -> bool> {
283 state: S,
284 notify: F,
285 observers: HashMap<K, Window<O>>,
286}
287
288impl<S, K, O, F> HangingGet<S, K, O, F>
289where
290 K: Eq + Hash,
291 F: Fn(&S, O) -> bool,
292{
293 fn notify_all(&mut self) {
294 for window in self.observers.values_mut() {
295 window.notify(&self.notify, &self.state);
296 }
297 }
298
299 pub fn new(state: S, notify: F) -> Self {
303 Self { state, notify, observers: HashMap::new() }
304 }
305
306 pub fn set(&mut self, state: S) {
311 self.state = state;
312 self.notify_all();
313 }
314
315 pub fn update(&mut self, state_update: impl FnOnce(&mut S) -> bool) {
318 if state_update(&mut self.state) {
319 self.notify_all();
320 }
321 }
322
323 pub fn subscribe(&mut self, key: K, observer: O) -> Result<(), HangingGetServerError> {
331 self.observers.entry(key).or_insert_with(Window::new).observe(
332 observer,
333 &self.notify,
334 &self.state,
335 )
336 }
337
338 pub fn unsubscribe(&mut self, key: K) {
341 drop(self.observers.remove(&key));
342 }
343}
344
345struct Window<O> {
348 dirty: bool,
349 observer: Option<O>,
350}
351
352impl<O> Window<O> {
353 pub fn new() -> Self {
355 Window { dirty: true, observer: None }
356 }
357
358 pub fn observe<S>(
362 &mut self,
363 observer: O,
364 f: impl Fn(&S, O) -> bool,
365 current_state: &S,
366 ) -> Result<(), HangingGetServerError> {
367 if self.observer.is_some() {
368 return Err(HangingGetServerError::MultipleObservers);
369 }
370 self.observer = Some(observer);
371 if self.dirty {
372 self.notify(f, current_state);
373 }
374 Ok(())
375 }
376
377 pub fn notify<S>(&mut self, f: impl Fn(&S, O) -> bool, state: &S) {
381 match self.observer.take() {
382 Some(observer) => {
383 if f(state, observer) {
384 self.dirty = false;
385 }
386 }
387 None => self.dirty = true,
388 }
389 }
390}
391
mod subscriber_key {
    /// Hands out [`Key`]s in increasing order, never repeating one.
    ///
    /// Keys start at 0. Once the internal counter can no longer be incremented, the sequence
    /// ends and yields `None` on every later call.
    #[derive(Debug)]
    pub struct Generator {
        /// The key the next call to `next` will return.
        next: Key,
    }

    impl Default for Generator {
        fn default() -> Self {
            Self { next: Key(0) }
        }
    }

    // Implementing `Iterator` instead of an inherent `next` method:
    // - existing `.next()` call sites keep working through the prelude;
    // - it satisfies clippy's `should_implement_trait`;
    // - it makes the iterator adaptors available.
    impl Iterator for Generator {
        type Item = Key;

        fn next(&mut self) -> Option<Key> {
            // `Key` is `Copy`, so no clone is needed.
            let key = self.next;
            // Stop without advancing once the counter would overflow.
            self.next.0 = self.next.0.checked_add(1)?;
            Some(key)
        }
    }

    /// Opaque identifier distinguishing one subscriber from another.
    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    pub struct Key(u64);
}
423
/// Unit tests for the key generator, `Window`, `HangingGet`, `Publisher`, and the async broker.
///
/// NOTE(review): many tests below end without an explicit assertion. They presumably rely on
/// `TestObserver` checking its `expect_value`/`expect_no_value` expectation when dropped —
/// confirm in `async_utils::hanging_get::test_util`. That makes statement order and observer
/// lifetimes significant.
#[cfg(test)]
mod tests {
    use super::*;
    use async_utils::hanging_get::test_util::TestObserver;
    use fuchsia_async as fasync;
    use futures::channel::oneshot;
    use std::pin::pin;
    use std::task::Poll;

    const TEST_CHANNEL_SIZE: usize = 128;

    // Two consecutive keys from one generator must differ.
    #[test]
    fn subscriber_key_generator_creates_unique_keys() {
        let mut gen = subscriber_key::Generator::default();
        let key1 = gen.next();
        let key2 = gen.next();
        assert!(key1 != key2);
    }

    // A fresh window is dirty, so the first observer is notified with the current state.
    #[test]
    fn window_add_first_observer_notifies() {
        let state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();
    }

    // After the first observer consumed the state, a second one waits without being notified.
    #[test]
    fn window_add_second_observer_does_not_notify() {
        let state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();

        window.observe(TestObserver::expect_no_value(), TestObserver::observe, &state).unwrap();
    }

    // A `notify` with no waiting observer marks the window dirty, so the next observer is
    // notified immediately with the new state.
    #[test]
    fn window_add_second_observer_notifies_after_notify_call() {
        let mut state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();

        state = 1;
        window.notify(TestObserver::observe, &state);

        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();
    }

    // Only one observer may wait at a time: the second is rejected with `MultipleObservers`,
    // and the waiting one receives the next notified state.
    #[test]
    fn window_add_multiple_observers_are_notified() {
        let mut state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();

        let o1 = TestObserver::expect_value(1);
        let o2 = TestObserver::expect_no_value();
        window.observe(o1.clone(), TestObserver::observe, &state).unwrap();
        let result = window.observe(o2.clone(), TestObserver::observe, &state);
        assert_eq!(result.unwrap_err(), HangingGetServerError::MultipleObservers);
        assert!(!o1.has_value());
        state = 1;
        window.notify(TestObserver::observe, &state);
    }

    // Walks the `dirty` flag through its transitions:
    // - consumed on observe → clean;
    // - notify with no waiting observer → dirty;
    // - observe again → clean.
    #[test]
    fn window_dirty_flag_state() {
        let state = 0;
        let mut window = Window::new();
        let o = TestObserver::expect_value(state);
        window.observe(o, TestObserver::observe, &state).unwrap();
        assert!(window.observer.is_none());
        assert!(!window.dirty);
        window.notify(TestObserver::observe, &state);
        assert!(window.dirty);
        let o = TestObserver::expect_value(state);
        window.observe(o, TestObserver::observe, &state).unwrap();
        assert!(!window.dirty);
    }

    // `observe_incomplete` presumably returns `false` (state not consumed), so the window must
    // stay dirty after notifying it.
    #[test]
    fn window_dirty_flag_respects_consumed_flag() {
        let state = 0;
        let mut window = Window::new();

        let o = TestObserver::expect_value(state);
        window.observe(o, TestObserver::observe_incomplete, &state).unwrap();
        assert!(window.dirty);
    }

    // The first subscription for a key is notified with the initial state.
    #[test]
    fn hanging_get_subscribe() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        let o = TestObserver::expect_value(0);
        assert!(!o.has_value());
        hanging.subscribe(0, o.clone()).unwrap();
    }

    // An observer that already consumed the initial state is not notified again by `set`,
    // because it is no longer waiting.
    #[test]
    fn hanging_get_subscribe_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        let o = TestObserver::expect_value(0);
        hanging.subscribe(0, o.clone()).unwrap();

        hanging.set(1);
    }

    // The second subscription for the same key waits and receives the value from `set`.
    #[test]
    fn hanging_get_subscribe_twice_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();

        hanging.subscribe(0, TestObserver::expect_value(1)).unwrap();
        hanging.set(1);
    }

    // A third subscription while one is already waiting is rejected; the waiting one still
    // receives the value from `set`.
    #[test]
    fn hanging_get_subscribe_multiple_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();

        let o2 = TestObserver::expect_value(1);
        hanging.subscribe(0, o2.clone()).unwrap();
        assert!(!o2.has_value());

        let _ = hanging.subscribe(0, TestObserver::expect_no_value()).unwrap_err();

        hanging.set(1);
    }

    // Different keys have independent windows; both waiting observers receive the update.
    #[test]
    fn hanging_get_subscribe_with_two_clients_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(0, TestObserver::expect_value(1)).unwrap();
        hanging.subscribe(1, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(1, TestObserver::expect_value(1)).unwrap();
        hanging.set(1);
    }

    // Unsubscribing drops the waiting observer, so it never sees the later `set`.
    #[test]
    fn hanging_get_unsubscribe() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(0, TestObserver::expect_no_value()).unwrap();
        hanging.unsubscribe(0);
        hanging.set(1);
    }

    // Unsubscribing one key removes only that key's window.
    #[test]
    fn hanging_get_unsubscribe_one_of_many() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);

        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(0, TestObserver::expect_no_value()).unwrap();
        hanging.subscribe(1, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(1, TestObserver::expect_no_value()).unwrap();

        hanging.unsubscribe(0);
        assert!(!hanging.observers.contains_key(&0));
        assert!(hanging.observers.contains_key(&1));
    }

    // `Publisher::set` sends a closure that overwrites the state and reports a change.
    #[fasync::run_until_stalled(test)]
    async fn publisher_set_value() {
        let (sender, mut receiver) = mpsc::channel(128);
        let mut p = Publisher { sender };
        let mut value = 1i32;
        p.set(2i32).await.unwrap();
        let f = receiver.next().await.unwrap();
        assert_eq!(true, f(&mut value));
        assert_eq!(value, 2);
    }

    // `Publisher::update` forwards the caller's closure unchanged.
    #[fasync::run_until_stalled(test)]
    async fn publisher_update_value() {
        let (sender, mut receiver) = mpsc::channel(128);
        let mut p = Publisher { sender };
        let mut value = 1i32;
        p.update(|v| {
            *v += 1;
            true
        })
        .await
        .unwrap();
        let f = receiver.next().await.unwrap();
        assert_eq!(true, f(&mut value));
        assert_eq!(value, 2);
    }

    // The broker keeps running while any publisher or registrar exists, and finishes once
    // they are all dropped.
    #[test]
    fn pub_sub_empty_completes() {
        let mut ex = fasync::TestExecutor::new();
        let broker = HangingGetBroker::new(
            0i32,
            |s, o: oneshot::Sender<_>| o.send(s.clone()).map(|()| true).unwrap(),
            TEST_CHANNEL_SIZE,
        );
        let publisher = broker.new_publisher();
        let registrar = broker.new_registrar();
        let broker_future = broker.run();
        let mut broker_future = pin!(broker_future);

        assert_eq!(ex.run_until_stalled(&mut broker_future), Poll::Pending);

        drop(publisher);
        drop(registrar);

        assert_eq!(ex.run_until_stalled(&mut broker_future), Poll::Ready(()));
    }

    // End to end:
    // - the first registration gets the initial state;
    // - the second waits (`try_recv` yields `None`) until `set` publishes a new value.
    #[fasync::run_until_stalled(test)]
    async fn pub_sub_updates_and_observes() {
        let broker = HangingGetBroker::new(
            0i32,
            |s, o: oneshot::Sender<_>| o.send(s.clone()).map(|()| true).unwrap(),
            TEST_CHANNEL_SIZE,
        );
        let mut publisher = broker.new_publisher();
        let mut registrar = broker.new_registrar();
        let fut = async move {
            let mut subscriber = registrar.new_subscriber().await.unwrap();

            let (sender, receiver) = oneshot::channel();
            subscriber.register(sender).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 0);

            let (sender, mut receiver) = oneshot::channel();
            subscriber.register(sender).await.unwrap();
            assert!(receiver.try_recv().unwrap().is_none());
            publisher.set(1).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 1);
        };

        // `fut` owns the publisher and registrar, so the broker finishes once `fut` completes
        // and drops them.
        futures::join!(fut, broker.run());
    }

    // Two subscribers are tracked independently: both get the initial state, both wait, and
    // both receive the same update.
    #[fasync::run_until_stalled(test)]
    async fn pub_sub_multiple_subscribers() {
        let broker = HangingGetBroker::new(
            0i32,
            |s, o: oneshot::Sender<_>| o.send(s.clone()).map(|()| true).unwrap(),
            TEST_CHANNEL_SIZE,
        );
        let mut publisher = broker.new_publisher();
        let mut registrar = broker.new_registrar();
        let fut = async move {
            let mut sub1 = registrar.new_subscriber().await.unwrap();
            let mut sub2 = registrar.new_subscriber().await.unwrap();

            let (sender, receiver) = oneshot::channel();
            sub1.register(sender).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 0);

            let (sender, receiver) = oneshot::channel();
            sub2.register(sender).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 0);

            let (sender, mut recv1) = oneshot::channel();
            sub1.register(sender).await.unwrap();
            assert!(recv1.try_recv().unwrap().is_none());

            let (sender, mut recv2) = oneshot::channel();
            sub2.register(sender).await.unwrap();
            assert!(recv2.try_recv().unwrap().is_none());

            publisher.set(1).await.unwrap();
            assert_eq!(recv1.await.unwrap(), 1);
            assert_eq!(recv2.await.unwrap(), 1);
        };

        futures::join!(fut, broker.run());
    }
}