wlancfg_lib/util/
state_machine.rs1use fuchsia_sync::RwLock;
6use futures::future::{Future, FutureExt, LocalFutureObj};
7use futures::ready;
8use futures::task::{Context, Poll};
9use std::convert::Infallible;
10use std::fmt::Debug;
11use std::pin::Pin;
12use std::sync::Arc;
13
14#[derive(Debug)]
15pub struct ExitReason(pub Result<(), anyhow::Error>);
16
17pub struct State<E>(LocalFutureObj<'static, Result<State<E>, E>>);
18
19pub struct StateMachine<E> {
20 cur_state: State<E>,
21}
22
23impl<E> Unpin for StateMachine<E> {}
24
25impl<E> Future for StateMachine<E> {
26 type Output = Result<Infallible, E>;
27
28 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
29 loop {
30 match ready!(self.cur_state.0.poll_unpin(cx)) {
31 Ok(next) => self.cur_state = next,
32 Err(e) => return Poll::Ready(Err(e)),
33 }
34 }
35 }
36}
37
38pub trait IntoStateExt<E>: Future<Output = Result<State<E>, E>> {
39 fn into_state(self) -> State<E>
40 where
41 Self: Sized + 'static,
42 {
43 State(LocalFutureObj::new(Box::new(self)))
44 }
45
46 fn into_state_machine(self) -> StateMachine<E>
47 where
48 Self: Sized + 'static,
49 {
50 StateMachine { cur_state: self.into_state() }
51 }
52}
53
54impl<F, E> IntoStateExt<E> for F where F: Future<Output = Result<State<E>, E>> {}
55
56#[derive(Clone)]
59pub struct StateMachineStatusPublisher<S>(Arc<RwLock<S>>);
60
61impl<S: Clone + Debug> StateMachineStatusPublisher<S> {
62 pub fn publish_status(&self, status: S) {
63 *self.0.write() = status;
64 }
65}
66
67#[derive(Clone)]
68pub struct StateMachineStatusReader<S>(Arc<RwLock<S>>);
69
70impl<S: Clone + Debug> StateMachineStatusReader<S> {
71 pub fn read_status(&self) -> Result<S, anyhow::Error> {
72 Ok(self.0.read().clone())
73 }
74}
75
76pub fn status_publisher_and_reader<S: Clone + Default>()
77-> (StateMachineStatusPublisher<S>, StateMachineStatusReader<S>) {
78 let status = Arc::new(RwLock::new(S::default()));
79 (StateMachineStatusPublisher(status.clone()), StateMachineStatusReader(status))
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::stream::StreamExt;
    use std::mem;

    /// Feeds numbers into a summing state machine. Closing the input must surface the
    /// accumulated total as the machine's error value.
    #[fuchsia::test]
    fn state_machine() {
        let mut exec = fasync::TestExecutor::new();
        let (tx, rx) = mpsc::unbounded();
        let mut machine = sum_state(0, rx).into_state_machine();

        // Nothing has been sent yet, so the first state is waiting on the stream.
        assert_eq!(Poll::Pending, exec.run_until_stalled(&mut machine));

        for value in [2, 3] {
            tx.unbounded_send(value).unwrap();
        }
        // Dropping the sender ends the stream, so the last state returns `Err(total)`.
        mem::drop(tx);

        assert_eq!(Poll::Ready(Err(5)), exec.run_until_stalled(&mut machine));
    }

    /// One state of the summing machine.
    ///
    /// It waits for the next number. If one arrives, it moves to a new state carrying the
    /// updated total. If the stream has ended, it finishes with `Err(current)`.
    async fn sum_state(
        current: u32,
        stream: mpsc::UnboundedReceiver<u32>,
    ) -> Result<State<u32>, u32> {
        let (next, rest) = stream.into_future().await;
        next.map(|number| make_sum_state(current + number, rest)).ok_or(current)
    }

    // Non-`async` wrapper with a concrete `State<u32>` return type.
    // NOTE(review): this appears to exist so `sum_state` can transition to itself without
    // the compiler rejecting a recursive `async fn` opaque type. Confirm before inlining it.
    fn make_sum_state(current: u32, stream: mpsc::UnboundedReceiver<u32>) -> State<u32> {
        let next = sum_state(current, stream);
        next.into_state()
    }

    #[derive(Clone, Debug, Default, PartialEq)]
    enum FakeState {
        #[default]
        Beginning,
        Middle,
        End,
    }

    /// The reader sees the default status first, then each value the publisher stores.
    #[fuchsia::test]
    fn state_publish_and_read() {
        let _exec = fasync::TestExecutor::new();
        let (publisher, reader) = status_publisher_and_reader::<FakeState>();
        assert_eq!(reader.read_status().expect("failed to read status"), FakeState::Beginning);

        for expected in [FakeState::Middle, FakeState::End] {
            publisher.publish_status(expected.clone());
            assert_eq!(reader.read_status().expect("failed to read status"), expected);
        }
    }
}