wlancfg_lib/util/state_machine.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_sync::RwLock;
6use futures::future::{Future, FutureExt, LocalFutureObj};
7use futures::ready;
8use futures::task::{Context, Poll};
9use std::convert::Infallible;
10use std::fmt::Debug;
11use std::pin::Pin;
12use std::sync::Arc;
13
14#[derive(Debug)]
15pub struct ExitReason(pub Result<(), anyhow::Error>);
16
17pub struct State<E>(LocalFutureObj<'static, Result<State<E>, E>>);
18
19pub struct StateMachine<E> {
20    cur_state: State<E>,
21}
22
23impl<E> Unpin for StateMachine<E> {}
24
25impl<E> Future for StateMachine<E> {
26    type Output = Result<Infallible, E>;
27
28    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
29        loop {
30            match ready!(self.cur_state.0.poll_unpin(cx)) {
31                Ok(next) => self.cur_state = next,
32                Err(e) => return Poll::Ready(Err(e)),
33            }
34        }
35    }
36}
37
38pub trait IntoStateExt<E>: Future<Output = Result<State<E>, E>> {
39    fn into_state(self) -> State<E>
40    where
41        Self: Sized + 'static,
42    {
43        State(LocalFutureObj::new(Box::new(self)))
44    }
45
46    fn into_state_machine(self) -> StateMachine<E>
47    where
48        Self: Sized + 'static,
49    {
50        StateMachine { cur_state: self.into_state() }
51    }
52}
53
54impl<F, E> IntoStateExt<E> for F where F: Future<Output = Result<State<E>, E>> {}
55
56// Some helpers to allow state machines to publish state and other futures to check in on the most
57// recent state updates.
58#[derive(Clone)]
59pub struct StateMachineStatusPublisher<S>(Arc<RwLock<S>>);
60
61impl<S: Clone + Debug> StateMachineStatusPublisher<S> {
62    pub fn publish_status(&self, status: S) {
63        *self.0.write() = status;
64    }
65}
66
67#[derive(Clone)]
68pub struct StateMachineStatusReader<S>(Arc<RwLock<S>>);
69
70impl<S: Clone + Debug> StateMachineStatusReader<S> {
71    pub fn read_status(&self) -> Result<S, anyhow::Error> {
72        Ok(self.0.read().clone())
73    }
74}
75
76pub fn status_publisher_and_reader<S: Clone + Default>()
77-> (StateMachineStatusPublisher<S>, StateMachineStatusReader<S>) {
78    let status = Arc::new(RwLock::new(S::default()));
79    (StateMachineStatusPublisher(status.clone()), StateMachineStatusReader(status))
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::stream::StreamExt;

    #[fuchsia::test]
    fn state_machine() {
        let mut executor = fasync::TestExecutor::new();
        let (tx, rx) = mpsc::unbounded();
        let mut machine = accumulate(0, rx).into_state_machine();

        // Nothing has been sent yet, so the first state is waiting on the channel.
        assert_eq!(Poll::Pending, executor.run_until_stalled(&mut machine));

        for value in [2, 3] {
            tx.unbounded_send(value).unwrap();
        }
        // Closing the channel makes the final state fail with the running total.
        drop(tx);

        assert_eq!(Poll::Ready(Err(5)), executor.run_until_stalled(&mut machine));
    }

    /// Waits for the next number on `stream` and moves to a state holding `current + number`, or
    /// fails with `current` once the stream has closed.
    async fn accumulate(
        current: u32,
        stream: mpsc::UnboundedReceiver<u32>,
    ) -> Result<State<u32>, u32> {
        match stream.into_future().await {
            (Some(number), rest) => Ok(boxed_accumulate(current + number, rest)),
            (None, _) => Err(current),
        }
    }

    // Boxing through `into_state` works around the compiler's "recursive impl Trait" limitation
    // that a direct self-call from `accumulate` would hit.
    fn boxed_accumulate(current: u32, stream: mpsc::UnboundedReceiver<u32>) -> State<u32> {
        accumulate(current, stream).into_state()
    }

    #[derive(Clone, Debug, Default, PartialEq)]
    enum FakeState {
        #[default]
        Beginning,
        Middle,
        End,
    }

    #[fuchsia::test]
    fn state_publish_and_read() {
        let _exec = fasync::TestExecutor::new();
        let (publisher, reader) = status_publisher_and_reader::<FakeState>();
        assert_eq!(reader.read_status().expect("failed to read status"), FakeState::Beginning);

        for next in [FakeState::Middle, FakeState::End] {
            publisher.publish_status(next.clone());
            assert_eq!(reader.read_status().expect("failed to read status"), next);
        }
    }
}