1use crate::responding_channel as responding;
6use async_utils::hanging_get::error::HangingGetServerError;
7use async_utils::stream::{StreamItem, WithEpitaph};
8use core::hash::Hash;
9use futures::channel::mpsc;
10use futures::{select, SinkExt, StreamExt};
11use std::collections::HashMap;
12
/// Buffer size intended (judging by its name) as the default `channel_size` for a
/// [`HangingGetBroker`].
///
/// NOTE(review): not referenced by the code visible in this file; presumably callers pass it as
/// the `channel_size` argument of [`HangingGetBroker::new`] — confirm.
pub const DEFAULT_CHANNEL_SIZE: usize = 128;
16
/// Owns a hanging-get state and serves it to subscribers over channels.
///
/// State changes arrive from [`Publisher`] handles, new subscribers are created through
/// [`SubscriptionRegistrar`] handles, and all of it is processed by [`HangingGetBroker::run`].
pub struct HangingGetBroker<S, O: Unpin + 'static, F: Fn(&S, O) -> bool> {
    /// The hanging-get bookkeeping (state, notify function, per-subscriber windows), keyed by
    /// subscriber key.
    inner: HangingGet<S, subscriber_key::Key, O, F>,
    /// The broker's own publisher handle; cloned by `new_publisher`, dropped when `run` starts.
    publisher: Publisher<S>,
    /// Receiving end of the channel that all publishers send update closures into.
    updates: mpsc::Receiver<UpdateFn<S>>,
    /// The broker's own registrar handle; cloned by `new_registrar`, dropped when `run` starts.
    registrar: SubscriptionRegistrar<O>,
    /// Receiving end for new-subscriber requests made through registrars.
    subscription_requests: responding::Receiver<(), Subscriber<O>>,
    /// Source of the unique key assigned to each accepted subscriber.
    subscriber_key_generator: subscriber_key::Generator,
    /// Buffer size used when creating the per-subscriber channels in `run`.
    channel_size: usize,
}
92
93impl<S, O, F> HangingGetBroker<S, O, F>
94where
95    S: Clone + Send,
96    O: Send + Unpin + 'static,
97    F: Fn(&S, O) -> bool,
98{
99    pub fn new(state: S, notify: F, channel_size: usize) -> Self {
104        let (sender, updates) = mpsc::channel(channel_size);
105        let publisher = Publisher { sender };
106        let (sender, subscription_requests) = responding::channel(channel_size);
107        let registrar = SubscriptionRegistrar { sender };
108        Self {
109            inner: HangingGet::new(state, notify),
110            publisher,
111            updates,
112            registrar,
113            subscription_requests,
114            subscriber_key_generator: subscriber_key::Generator::default(),
115            channel_size,
116        }
117    }
118
119    pub fn new_publisher(&self) -> Publisher<S> {
122        self.publisher.clone()
123    }
124
125    pub fn new_registrar(&self) -> SubscriptionRegistrar<O> {
128        self.registrar.clone()
129    }
130
131    pub async fn run(mut self) {
135        drop(self.publisher);
139        drop(self.registrar);
140
141        let mut subscriptions = futures::stream::SelectAll::new();
145
146        loop {
147            select! {
148                update = self.updates.next() => {
150                    if let Some(update) = update {
151                        self.inner.update(update)
152                    }
153                }
154                subscriber = self.subscription_requests.next() => {
156                    if let Some((_, responder)) = subscriber {
157                        let (sender, receiver) = responding::channel(self.channel_size);
158                        let key = self.subscriber_key_generator.next().unwrap();
159                        if let Ok(()) = responder.respond(sender.into()) {
160                            subscriptions.push(receiver.map(move |o| (key, o)).with_epitaph(key));
161                        }
162                    }
163                }
164                observer = subscriptions.next() => {
166                    match observer {
167                        Some(StreamItem::Item((key, (observer, responder)))) => {
168                            let _ = responder.respond(self.inner.subscribe(key, observer));
169                        },
170                        Some(StreamItem::Epitaph(key)) => {
171                            self.inner.unsubscribe(key);
172                        },
173                        None => (),
174                    }
175                }
176                complete => break,
178            }
179        }
180    }
181}
182
183pub struct SubscriptionRegistrar<O> {
186    sender: responding::Sender<(), Subscriber<O>>,
187}
188
189impl<O> Clone for SubscriptionRegistrar<O> {
190    fn clone(&self) -> Self {
191        Self { sender: self.sender.clone() }
192    }
193}
194
195impl<O> SubscriptionRegistrar<O> {
196    pub async fn new_subscriber(&mut self) -> Result<Subscriber<O>, HangingGetServerError> {
198        Ok(self.sender.request(()).await?)
199    }
200}
201
202pub struct Subscriber<O> {
206    sender: responding::Sender<O, Result<(), HangingGetServerError>>,
207}
208
209impl<O> From<responding::Sender<O, Result<(), HangingGetServerError>>> for Subscriber<O> {
210    fn from(sender: responding::Sender<O, Result<(), HangingGetServerError>>) -> Self {
211        Self { sender }
212    }
213}
214
215impl<O> Subscriber<O> {
216    pub async fn register(&mut self, observation: O) -> Result<(), HangingGetServerError> {
221        self.sender.request(observation).await?
222    }
223}
224
/// A boxed, one-shot state mutation sent from a [`Publisher`] to the broker.
///
/// It is applied to the current state in place. Returning `true` tells the hanging get that
/// observers should be notified of the change.
type UpdateFn<S> = Box<dyn FnOnce(&mut S) -> bool + Send + 'static>;
228
229pub struct Publisher<S> {
232    sender: mpsc::Sender<UpdateFn<S>>,
233}
234
235impl<S> Clone for Publisher<S> {
236    fn clone(&self) -> Self {
237        Publisher { sender: self.sender.clone() }
238    }
239}
240
241impl<S> Publisher<S>
242where
243    S: Send + 'static,
244{
245    pub async fn set(&mut self, state: S) -> Result<(), HangingGetServerError> {
248        Ok(self
249            .sender
250            .send(Box::new(move |s| {
251                *s = state;
252                true
253            }))
254            .await?)
255    }
256
257    pub async fn update<F>(&mut self, update: F) -> Result<(), HangingGetServerError>
260    where
261        F: FnOnce(&mut S) -> bool + Send + 'static,
262    {
263        Ok(self.sender.send(Box::new(update)).await?)
264    }
265}
266
267pub struct HangingGet<S, K, O, F: Fn(&S, O) -> bool> {
283    state: S,
284    notify: F,
285    observers: HashMap<K, Window<O>>,
286}
287
288impl<S, K, O, F> HangingGet<S, K, O, F>
289where
290    K: Eq + Hash,
291    F: Fn(&S, O) -> bool,
292{
293    fn notify_all(&mut self) {
294        for window in self.observers.values_mut() {
295            window.notify(&self.notify, &self.state);
296        }
297    }
298
299    pub fn new(state: S, notify: F) -> Self {
303        Self { state, notify, observers: HashMap::new() }
304    }
305
306    pub fn set(&mut self, state: S) {
311        self.state = state;
312        self.notify_all();
313    }
314
315    pub fn update(&mut self, state_update: impl FnOnce(&mut S) -> bool) {
318        if state_update(&mut self.state) {
319            self.notify_all();
320        }
321    }
322
323    pub fn subscribe(&mut self, key: K, observer: O) -> Result<(), HangingGetServerError> {
331        self.observers.entry(key).or_insert_with(Window::new).observe(
332            observer,
333            &self.notify,
334            &self.state,
335        )
336    }
337
338    pub fn unsubscribe(&mut self, key: K) {
341        drop(self.observers.remove(&key));
342    }
343}
344
345struct Window<O> {
348    dirty: bool,
349    observer: Option<O>,
350}
351
352impl<O> Window<O> {
353    pub fn new() -> Self {
355        Window { dirty: true, observer: None }
356    }
357
358    pub fn observe<S>(
362        &mut self,
363        observer: O,
364        f: impl Fn(&S, O) -> bool,
365        current_state: &S,
366    ) -> Result<(), HangingGetServerError> {
367        if self.observer.is_some() {
368            return Err(HangingGetServerError::MultipleObservers);
369        }
370        self.observer = Some(observer);
371        if self.dirty {
372            self.notify(f, current_state);
373        }
374        Ok(())
375    }
376
377    pub fn notify<S>(&mut self, f: impl Fn(&S, O) -> bool, state: &S) {
381        match self.observer.take() {
382            Some(observer) => {
383                if f(state, observer) {
384                    self.dirty = false;
385                }
386            }
387            None => self.dirty = true,
388        }
389    }
390}
391
mod subscriber_key {
    /// Hands out unique [`Key`]s in increasing order, starting at `Key(0)`.
    ///
    /// Once the internal `u64` counter can no longer be advanced, the generator is exhausted
    /// and yields `None` on every subsequent call.
    pub struct Generator {
        /// The key returned by the next successful call to `next`.
        next: Key,
    }

    impl Default for Generator {
        fn default() -> Self {
            Self { next: Key(0) }
        }
    }

    // Implemented as `Iterator` rather than an inherent `next` with the identical signature
    // (clippy `should_implement_trait`); existing `.next()` callers are unaffected.
    impl Iterator for Generator {
        type Item = Key;

        fn next(&mut self) -> Option<Key> {
            // `Key` is `Copy`, so no clone is needed.
            let key = self.next;
            // Stay on the last value (and keep returning `None`) once the counter would overflow.
            let following = self.next.0.checked_add(1)?;
            self.next = Key(following);
            Some(key)
        }
    }

    /// Opaque identifier of one subscriber.
    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    pub struct Key(u64);
}
423
#[cfg(test)]
mod tests {
    use super::*;
    use async_utils::hanging_get::test_util::TestObserver;
    use fuchsia_async as fasync;
    use futures::channel::oneshot;
    use std::pin::pin;
    use std::task::Poll;

    const TEST_CHANNEL_SIZE: usize = 128;

    #[test]
    fn subscriber_key_generator_creates_unique_keys() {
        // Not named `gen`: that identifier is a reserved keyword in the 2024 edition.
        let mut generator = subscriber_key::Generator::default();
        let key1 = generator.next();
        let key2 = generator.next();
        assert_ne!(key1, key2);
    }

    #[test]
    fn window_add_first_observer_notifies() {
        // A fresh window is dirty, so its first observer is notified immediately.
        let state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();
    }

    #[test]
    fn window_add_second_observer_does_not_notify() {
        let state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();

        // The first notification cleared the dirty flag and the state has not changed, so the
        // second observer is parked rather than notified.
        window.observe(TestObserver::expect_no_value(), TestObserver::observe, &state).unwrap();
    }

    #[test]
    fn window_add_second_observer_notifies_after_notify_call() {
        let mut state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();

        // With no observer parked, `notify` only marks the window dirty...
        state = 1;
        window.notify(TestObserver::observe, &state);

        // ...so the next observer is notified with the new state right away.
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();
    }

    #[test]
    fn window_add_multiple_observers_are_notified() {
        let mut state = 0;
        let mut window = Window::new();
        window.observe(TestObserver::expect_value(state), TestObserver::observe, &state).unwrap();

        // `o1` is parked until the state changes; `o2` is rejected because only one observer
        // may be parked in a window at a time.
        let o1 = TestObserver::expect_value(1);
        let o2 = TestObserver::expect_no_value();
        window.observe(o1.clone(), TestObserver::observe, &state).unwrap();
        let result = window.observe(o2.clone(), TestObserver::observe, &state);
        assert_eq!(result.unwrap_err(), HangingGetServerError::MultipleObservers);
        assert!(!o1.has_value());
        state = 1;
        window.notify(TestObserver::observe, &state);
    }

    #[test]
    fn window_dirty_flag_state() {
        let state = 0;
        let mut window = Window::new();
        let o = TestObserver::expect_value(state);
        window.observe(o, TestObserver::observe, &state).unwrap();
        assert!(window.observer.is_none());
        assert!(!window.dirty);
        window.notify(TestObserver::observe, &state);
        assert!(window.dirty);
        let o = TestObserver::expect_value(state);
        window.observe(o, TestObserver::observe, &state).unwrap();
        assert!(!window.dirty);
    }

    #[test]
    fn window_dirty_flag_respects_consumed_flag() {
        let state = 0;
        let mut window = Window::new();

        // `observe_incomplete` presumably reports the notification as not consumed, so the
        // window must stay dirty.
        let o = TestObserver::expect_value(state);
        window.observe(o, TestObserver::observe_incomplete, &state).unwrap();
        assert!(window.dirty);
    }

    #[test]
    fn hanging_get_subscribe() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        let o = TestObserver::expect_value(0);
        assert!(!o.has_value());
        hanging.subscribe(0, o.clone()).unwrap();
    }

    #[test]
    fn hanging_get_subscribe_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        let o = TestObserver::expect_value(0);
        hanging.subscribe(0, o.clone()).unwrap();

        // The only observer was already answered, so this change has no one to notify.
        hanging.set(1);
    }

    #[test]
    fn hanging_get_subscribe_twice_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();

        // The second observer is parked and receives the value from the following `set`.
        hanging.subscribe(0, TestObserver::expect_value(1)).unwrap();
        hanging.set(1);
    }

    #[test]
    fn hanging_get_subscribe_multiple_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();

        // `o2` is parked: nothing has changed since the first observer was answered.
        let o2 = TestObserver::expect_value(1);
        hanging.subscribe(0, o2.clone()).unwrap();
        assert!(!o2.has_value());

        // A third observer for the same key is rejected while `o2` is parked.
        let _ = hanging.subscribe(0, TestObserver::expect_no_value()).unwrap_err();

        // The parked `o2` is notified with the new state.
        hanging.set(1);
    }

    #[test]
    fn hanging_get_subscribe_with_two_clients_then_set() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(0, TestObserver::expect_value(1)).unwrap();
        hanging.subscribe(1, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(1, TestObserver::expect_value(1)).unwrap();
        hanging.set(1);
    }

    #[test]
    fn hanging_get_unsubscribe() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);
        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(0, TestObserver::expect_no_value()).unwrap();
        // Unsubscribing drops the parked observer, so the `set` below reaches no one.
        hanging.unsubscribe(0);
        hanging.set(1);
    }

    #[test]
    fn hanging_get_unsubscribe_one_of_many() {
        let mut hanging = HangingGet::new(0, TestObserver::observe);

        hanging.subscribe(0, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(0, TestObserver::expect_no_value()).unwrap();
        hanging.subscribe(1, TestObserver::expect_value(0)).unwrap();
        hanging.subscribe(1, TestObserver::expect_no_value()).unwrap();

        hanging.unsubscribe(0);
        // Only the window for key 0 is removed.
        assert!(!hanging.observers.contains_key(&0));
        assert!(hanging.observers.contains_key(&1));
    }

    #[fasync::run_until_stalled(test)]
    async fn publisher_set_value() {
        let (sender, mut receiver) = mpsc::channel(128);
        let mut p = Publisher { sender };
        let mut value = 1i32;
        p.set(2i32).await.unwrap();
        let f = receiver.next().await.unwrap();
        assert!(f(&mut value));
        assert_eq!(value, 2);
    }

    #[fasync::run_until_stalled(test)]
    async fn publisher_update_value() {
        let (sender, mut receiver) = mpsc::channel(128);
        let mut p = Publisher { sender };
        let mut value = 1i32;
        p.update(|v| {
            *v += 1;
            true
        })
        .await
        .unwrap();
        let f = receiver.next().await.unwrap();
        assert!(f(&mut value));
        assert_eq!(value, 2);
    }

    #[test]
    fn pub_sub_empty_completes() {
        let mut ex = fasync::TestExecutor::new();
        let broker = HangingGetBroker::new(
            0i32,
            |s, o: oneshot::Sender<_>| o.send(s.clone()).map(|()| true).unwrap(),
            TEST_CHANNEL_SIZE,
        );
        let publisher = broker.new_publisher();
        let registrar = broker.new_registrar();
        let broker_future = broker.run();
        let mut broker_future = pin!(broker_future);

        // The broker keeps running while a publisher and a registrar are alive.
        assert_eq!(ex.run_until_stalled(&mut broker_future), Poll::Pending);

        drop(publisher);
        drop(registrar);

        // With every handle gone, all streams terminate and `run` completes.
        assert_eq!(ex.run_until_stalled(&mut broker_future), Poll::Ready(()));
    }

    #[fasync::run_until_stalled(test)]
    async fn pub_sub_updates_and_observes() {
        let broker = HangingGetBroker::new(
            0i32,
            |s, o: oneshot::Sender<_>| o.send(s.clone()).map(|()| true).unwrap(),
            TEST_CHANNEL_SIZE,
        );
        let mut publisher = broker.new_publisher();
        let mut registrar = broker.new_registrar();
        let fut = async move {
            let mut subscriber = registrar.new_subscriber().await.unwrap();

            // The first observer receives the initial state immediately.
            let (sender, receiver) = oneshot::channel();
            subscriber.register(sender).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 0);

            // The second observer hangs until the state changes.
            let (sender, mut receiver) = oneshot::channel();
            subscriber.register(sender).await.unwrap();
            assert!(receiver.try_recv().unwrap().is_none());
            publisher.set(1).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 1);
        };

        // `fut` drops every handle when it finishes, which lets `broker.run()` complete too.
        futures::join!(fut, broker.run());
    }

    #[fasync::run_until_stalled(test)]
    async fn pub_sub_multiple_subscribers() {
        let broker = HangingGetBroker::new(
            0i32,
            |s, o: oneshot::Sender<_>| o.send(s.clone()).map(|()| true).unwrap(),
            TEST_CHANNEL_SIZE,
        );
        let mut publisher = broker.new_publisher();
        let mut registrar = broker.new_registrar();
        let fut = async move {
            let mut sub1 = registrar.new_subscriber().await.unwrap();
            let mut sub2 = registrar.new_subscriber().await.unwrap();

            // Each subscriber's first observer receives the initial state immediately.
            let (sender, receiver) = oneshot::channel();
            sub1.register(sender).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 0);

            let (sender, receiver) = oneshot::channel();
            sub2.register(sender).await.unwrap();
            assert_eq!(receiver.await.unwrap(), 0);

            // Follow-up observers of both subscribers hang until the state changes.
            let (sender, mut recv1) = oneshot::channel();
            sub1.register(sender).await.unwrap();
            assert!(recv1.try_recv().unwrap().is_none());

            let (sender, mut recv2) = oneshot::channel();
            sub2.register(sender).await.unwrap();
            assert!(recv2.try_recv().unwrap().is_none());

            publisher.set(1).await.unwrap();
            assert_eq!(recv1.await.unwrap(), 1);
            assert_eq!(recv2.await.unwrap(), 1);
        };

        // `fut` drops every handle when it finishes, which lets `broker.run()` complete too.
        futures::join!(fut, broker.run());
    }
}