wlan_statemachine/
lib.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Generic state machine implementation with compile time checked state transitions.
6
7use std::fmt::Debug;
8use std::marker::PhantomData;
9use std::ops::{Deref, DerefMut};
10
11pub use wlan_statemachine_macro::statemachine;
12
/// Wrapper to safely replace the states of a state machine whose transitions don't consume
/// their states.
///
/// Use this wrapper if state transitions are performed on mutable references rather than on
/// consumed states. The current state is moved out while a transition closure runs and the
/// closure's result is stored as the new state.
///
/// # Example
///
/// ```ignore
/// fn on_event(event: Event, statemachine: &mut StateMachine<State>) {
///     statemachine.replace_state(|state| match state {
///         State::A => match event {
///             Event::A => State::B,
///             _ => state,
///         },
///         State::B => {
///             warn!("cannot receive events in State::B");
///             state
///         }
///     });
/// }
/// ```
pub struct StateMachine<S> {
    // `None` only while a transition closure is running, or afterwards if that closure
    // panicked or `try_replace_state` returned an error (the closure consumed the old state).
    state: Option<S>,
}

// Panic message for accessing a `StateMachine` whose state was lost.
const STATE_LOST: &str = "StateMachine has no state: a previous `try_replace_state` call failed \
                          or a transition closure panicked";

impl<S: Debug> Debug for StateMachine<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print the raw `Option` so that a lost state shows up as `None` instead of panicking.
        write!(f, "State: {:?}", self.state)
    }
}

impl<S: PartialEq> PartialEq for StateMachine<S> {
    fn eq(&self, other: &Self) -> bool {
        self.state == other.state
    }
}

impl<S> StateMachine<S> {
    /// Constructs a new `StateMachine` starting in `state`.
    pub fn new(state: S) -> Self {
        StateMachine { state: Some(state) }
    }

    /// Replaces the current state with one lazily constructed by `map`.
    ///
    /// `map` receives the current state by value and returns the next state.
    ///
    /// # Panics
    ///
    /// Panics if the state was lost earlier (see [`StateMachine::try_replace_state`]). If `map`
    /// itself panics, the state is lost and every later access to it panics as well.
    pub fn replace_state<F>(&mut self, map: F) -> &mut Self
    where
        F: FnOnce(S) -> S,
    {
        let state = self.state.take().expect(STATE_LOST);
        self.state = Some(map(state));
        self
    }

    /// Replaces the current state with one lazily constructed by the fallible `map`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `map`. Because `map` consumes the current state, the state
    /// machine is left without a state in that case. Every later access to the state panics
    /// until a new state is installed with [`StateMachine::replace_state_with`].
    ///
    /// # Panics
    ///
    /// Panics if the state was lost earlier.
    pub fn try_replace_state<F, E>(&mut self, map: F) -> Result<&mut Self, E>
    where
        F: FnOnce(S) -> Result<S, E>,
    {
        let state = self.state.take().expect(STATE_LOST);
        self.state = Some(map(state)?);
        Ok(self)
    }

    /// Replaces the current state with `new_state`.
    ///
    /// This also restores a state machine whose state was lost.
    pub fn replace_state_with(&mut self, new_state: S) -> &mut Self {
        self.state = Some(new_state);
        self
    }

    /// Consumes the state machine and returns its current state.
    ///
    /// # Panics
    ///
    /// Panics if the state was lost earlier.
    pub fn into_state(self) -> S {
        self.state.expect(STATE_LOST)
    }
}

impl<S> AsRef<S> for StateMachine<S> {
    /// Borrows the current state. Panics if the state was lost.
    fn as_ref(&self) -> &S {
        self.state.as_ref().expect(STATE_LOST)
    }
}

impl<S> AsMut<S> for StateMachine<S> {
    /// Mutably borrows the current state. Panics if the state was lost.
    fn as_mut(&mut self) -> &mut S {
        self.state.as_mut().expect(STATE_LOST)
    }
}

impl<S> Deref for StateMachine<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        self.as_ref()
    }
}

impl<S> DerefMut for StateMachine<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut()
    }
}
108
109/// A `StateTransition` defines valid transitions from one state into another.
110/// Implement `StateTransition` on the given `State` struct to define a new
111/// state transition. Alternatively, use the convenience macro
112/// `statemachine!`.
113pub trait StateTransition<S> {
114    #[doc(hidden)]
115    fn __internal_transition_to(new_state: S) -> State<S>;
116}
117
/// Marker trait for the state a state machine starts in.
///
/// Only types implementing this trait can be passed to `State::new`. Every other state is
/// meant to be reached through a proper state transition instead of being built by hand.
pub trait InitialState {}
122
123/// Wrapper struct for a state S. Use in combination with `StateTransition`.
124pub struct State<S> {
125    pub data: S,
126    // Prevent public from constructing a State while granting access to `data` for partial
127    // matching multiple states.
128    __internal_phantom: PhantomData<S>,
129}
130impl<S> State<S> {
131    /// Construct the initial state of a state machine.
132    pub fn new(data: S) -> State<S>
133    where
134        S: InitialState,
135    {
136        Self::__internal_new(data)
137    }
138
139    // Note: must be public to be accessible through `statemachine!` macro.
140    #[doc(hidden)]
141    pub fn __internal_new(data: S) -> State<S> {
142        Self { data, __internal_phantom: PhantomData }
143    }
144
145    /// Releases the captured state data `S` and provides a transition instance
146    /// to perform a compile time checked state transition.
147    /// Use this function when the state data `S` is shared between multiple
148    /// states.
149    pub fn release_data(self) -> (Transition<S>, S) {
150        (Transition { _phantom: PhantomData }, self.data)
151    }
152
153    pub fn transition_to<T>(self, new_state: T) -> State<T>
154    where
155        S: StateTransition<T>,
156    {
157        S::__internal_transition_to(new_state)
158    }
159
160    pub fn apply<T, E>(self, transition: T) -> E
161    where
162        T: MultiTransition<E, S>,
163    {
164        transition.from(self)
165    }
166}
167
168/// Convenience functions for unit testing.
169/// Note: Do ONLY use in tests!
170pub mod testing {
171    use super::*;
172
173    /// Creates a new State with the given data.
174    pub fn new_state<S>(data: S) -> State<S> {
175        State::<S>::__internal_new(data)
176    }
177}
178impl<S> Deref for State<S> {
179    type Target = S;
180
181    fn deref(&self) -> &Self::Target {
182        &self.data
183    }
184}
185impl<S> DerefMut for State<S> {
186    fn deref_mut(&mut self) -> &mut Self::Target {
187        &mut self.data
188    }
189}
190
191impl<S: Debug> Debug for State<S> {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        write!(f, "State data: {:?}", self.data)
194    }
195}
196impl<S: PartialEq> PartialEq for State<S> {
197    fn eq(&self, other: &Self) -> bool {
198        self.data == other.data
199    }
200}
201
202/// Wrapper struct to enforce compile time checked state transitions of one state into another.
203pub struct Transition<S> {
204    _phantom: PhantomData<S>,
205}
206
207pub trait MultiTransition<E, S> {
208    fn from(self, state: State<S>) -> E;
209    fn via(self, transition: Transition<S>) -> E;
210}
211
212impl<S> Transition<S> {
213    pub fn to<T>(self, new_state: T) -> State<T>
214    where
215        S: StateTransition<T>,
216    {
217        S::__internal_transition_to(new_state)
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Default, Debug)]
    struct SharedStateData {
        foo: u8,
    }

    pub struct A;
    pub struct B(SharedStateData);
    pub struct C(SharedStateData);
    statemachine!(
        enum States,

        () => A,
        A => B,
        B => [C, A],
        C => [A],
        // Test duplicate transitions
        B => C,
        B => [A, C],
    );

    /// Chooses the target of a runtime-selected transition out of `B`: `A` for 0, `C` otherwise.
    fn multi_transition(foo: u8) -> BTransition {
        match foo {
            0 => BTransition::ToA(A),
            _ => BTransition::ToC(C(SharedStateData::default())),
        }
    }

    #[derive(Debug)]
    pub struct A2;
    #[derive(Debug)]
    pub struct B2;
    statemachine!(
        // Test derive attribute.
        #[derive(Debug)]
        enum States2,
        () => A2,
        A2 => B2,
        A2 => B2, // Test duplicate transitions
    );

    #[derive(Debug)]
    pub struct NonGen;
    #[derive(Debug)]
    pub struct Gen1<E>(E);
    #[derive(Debug)]
    pub struct Gen2<'a, F>(&'a Vec<F>);
    #[derive(Debug)]
    pub struct Gen3<'a, E, F>(E, Gen2<'a, F>);
    statemachine!(
        // Complicated test with generics and lifetimes.
        #[derive(Debug)]
        enum States3<'a, E, F>,
        () => Gen1<E>,
        Gen1<E> => [NonGen, Gen1<E>, Gen2<'a, F>, Gen3<'a, E, F>],
        NonGen => [NonGen, Gen1<E>],
        Gen2<'a, F> => Gen1<E>,
        Gen3<'a, E, F> => Gen2<'a, F>,
    );

    #[test]
    fn state_transitions() {
        let state = State::new(A);
        // Regular state transition:
        let state = state.transition_to(B(SharedStateData::default()));

        // Modify and share state data with new state.
        let (transition, mut data) = state.release_data();
        data.0.foo = 5;
        let state = transition.to(C(data.0));
        assert_eq!(state.0.foo, 5);
    }

    #[test]
    fn state_transition_self_transition() {
        let state = State::new(A);

        let state = state.transition_to(B(SharedStateData { foo: 5 }));
        let (transition, data) = state.release_data();
        assert_eq!(data.0.foo, 5);

        let state = transition.to(B(SharedStateData { foo: 2 }));
        let (_, data) = state.release_data();
        assert_eq!(data.0.foo, 2);
    }

    #[test]
    fn statemachine() {
        let mut statemachine = StateMachine::new(States::A(State::new(A)));
        statemachine.replace_state(|state| match state {
            States::A(state) => state.transition_to(B(SharedStateData::default())).into(),
            _ => state,
        });

        match statemachine.into_state() {
            States::B(State { data: B(SharedStateData { foo: 0 }), .. }) => (),
            _ => panic!("unexpected state"),
        }
    }

    #[test]
    fn statemachine_succeeds_try_replace_state() {
        #[derive(Debug)]
        struct Error;

        let mut statemachine = StateMachine::new(States2::A2(State::new(A2)));
        statemachine
            .try_replace_state(|state| match state {
                States2::A2(state) => Ok(state.transition_to(B2).into()),
                _ => Err(Error),
            })
            .expect("Failed to transition to B2");

        match statemachine.into_state() {
            States2::B2(_) => (),
            _ => panic!("unexpected state"),
        }
    }

    #[test]
    fn statemachine_fails_try_replace_state() {
        #[derive(Debug)]
        struct Error;

        let mut statemachine = StateMachine::new(States2::A2(State::new(A2)));

        statemachine
            .try_replace_state(|_state| Err(Error))
            .expect_err("try_replace_state() unexpectedly succeeded");

        // The failed transition consumed the old state; a new state can still be installed.
        statemachine.replace_state_with(States2::A2(State::new(A2)));
        match statemachine.into_state() {
            States2::A2(_) => (),
            _ => panic!("unexpected state"),
        }
    }

    #[test]
    fn transition_enums() {
        let state = State::new(A).transition_to(B(SharedStateData::default()));
        let transition = multi_transition(0);
        match state.apply(transition) {
            States::A(_) => (),
            _ => panic!("expected transition into A"),
        };
    }

    #[test]
    fn transition_enums_release() {
        let state = State::new(A).transition_to(B(SharedStateData::default()));
        let (transition, _data) = state.release_data();

        let target = multi_transition(0);
        match target.via(transition) {
            States::A(_) => (),
            _ => panic!("expected transition into A"),
        };
    }

    #[test]
    fn transition_enums_branching() {
        let state = State::new(A).transition_to(B(SharedStateData::default()));
        let (transition, _data) = state.release_data();

        let target = multi_transition(1);
        match target.via(transition) {
            States::C(_) => (),
            _ => panic!("expected transition into C"),
        };
    }

    #[test]
    fn generated_enum() {
        let _state_machine: States2 = match States2::A2(State::new(A2)) {
            // Test generated From impls:
            States2::A2(state) => state.transition_to(B2).into(),
            other => panic!("expected state A2 to be active: {:?}", other),
        };

        // No assertion needed. This test verifies that the enum "States2" was generated
        // properly.
    }

    #[test]
    fn generic_state_transitions() {
        let test_vec = vec![10, 20, 30];

        // Construct a generic state.
        let state = State::new(Gen1("test"));
        // Transition to a generic state with a lifetime.
        let state = state.transition_to(Gen2(&test_vec));
        // Store data with a lifetime, then reinsert into a later state.
        let (transition, data) = state.release_data();
        let state = transition.to(Gen1("test2"));
        let (transition, data2) = state.release_data();
        assert_eq!(data2.0, "test2");
        let state = transition.to(Gen3(data2.0, data));
        assert_eq!((state.1).0, &test_vec);
    }

    #[test]
    fn generated_generic_enum() {
        let _state_machine: States3<'_, &str, u16> =
            match States3::<&str, u16>::Gen1(State::new(Gen1("test"))) {
                // Test generated From impls:
                States3::Gen1(state) => state.transition_to(NonGen).into(),
                other => panic!("expected state Gen1 to be active: {:?}", other),
            };

        // No assertion needed. This test verifies that the generic enum "States3" was
        // generated properly.
    }
}