starnix_lifecycle/
atomic_counter.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Atomic counter types that can be shared across threads.
6
/// Macro to define an atomic counter for a given base type. This is necessary because Rust atomic
/// types are not parameterized on their base type.
///
/// `atomic_counter_definition!(u64)` defines `AtomicU64Counter`, a newtype around
/// `std::sync::atomic::AtomicU64`; other base types follow the same naming scheme.
macro_rules! atomic_counter_definition {
    ($ty:ty) => {
        paste::paste! {
            /// A counter that can be shared across threads and updated through `&self`.
            ///
            /// All atomic operations use `Ordering::Relaxed`: each update of the counter is
            /// atomic, but it imposes no ordering on other memory accesses, so the counter must
            /// not be used to synchronize access to other data.
            #[derive(Debug, Default)]
            pub struct [<Atomic $ty:camel Counter>](std::sync::atomic::[<Atomic $ty:camel>]);

            // Presumably not every generated method is used for every instantiated base type;
            // allow dead code so the unused ones do not produce warnings.
            #[allow(dead_code)]
            impl [<Atomic $ty:camel Counter>] {
                /// Creates a counter holding `value`. This is a `const fn`, so it can be used to
                /// initialize a `static`.
                pub const fn new(value: $ty) -> Self {
                    Self(std::sync::atomic::[<Atomic $ty:camel>]::new(value))
                }

                /// Increments the counter by one and returns the value it held *before* the
                /// increment. Wraps around on overflow.
                pub fn next(&self) -> $ty {
                    self.add(1)
                }

                /// Adds `amount` to the counter and returns the value it held *before* the
                /// addition. Wraps around on overflow.
                pub fn add(&self, amount: $ty) -> $ty {
                    self.0.fetch_add(amount, std::sync::atomic::Ordering::Relaxed)
                }

                /// Returns the current value of the counter.
                pub fn get(&self) -> $ty {
                    self.0.load(std::sync::atomic::Ordering::Relaxed)
                }

                /// Sets the counter to `value`.
                ///
                /// Takes `&mut self`, so exclusive access is guaranteed by the borrow checker and
                /// the value is written directly without an atomic operation.
                pub fn reset(&mut self, value: $ty) {
                    *self.0.get_mut() = value;
                }
            }

            impl From<$ty> for [<Atomic $ty:camel Counter>] {
                /// Creates a counter holding `value`; equivalent to `new`.
                fn from(value: $ty) -> Self {
                    Self::new(value)
                }
            }
        }
    };
}
45
46atomic_counter_definition!(u64);
47atomic_counter_definition!(u32);
48atomic_counter_definition!(usize);
49
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[::fuchsia::test]
    fn test_new() {
        // `new` is a `const fn`, so it can initialize a `static`. A `static` (not a `const`) is
        // required for a shared counter: every use of a `const` with interior mutability
        // creates a fresh copy, so updates through it would be lost.
        static COUNTER: AtomicU64Counter = AtomicU64Counter::new(0);
        assert_eq!(COUNTER.get(), 0);
        assert_eq!(COUNTER.next(), 0);
        assert_eq!(COUNTER.get(), 1);
    }

    #[::fuchsia::test]
    fn test_one_thread() {
        let mut counter = AtomicU64Counter::default();
        assert_eq!(counter.get(), 0);
        assert_eq!(counter.add(5), 0);
        assert_eq!(counter.get(), 5);
        assert_eq!(counter.next(), 5);
        assert_eq!(counter.get(), 6);
        counter.reset(2);
        assert_eq!(counter.get(), 2);
        assert_eq!(counter.next(), 2);
        assert_eq!(counter.get(), 3);
    }

    #[::fuchsia::test]
    fn test_from() {
        let counter = AtomicU64Counter::from(7);
        assert_eq!(counter.get(), 7);

        let counter: AtomicUsizeCounter = 3usize.into();
        assert_eq!(counter.next(), 3);
        assert_eq!(counter.get(), 4);
    }

    #[::fuchsia::test]
    fn test_wraps_on_overflow() {
        // `fetch_add` wraps around on overflow rather than panicking.
        let counter = AtomicU32Counter::new(u32::MAX);
        assert_eq!(counter.next(), u32::MAX);
        assert_eq!(counter.get(), 0);
    }

    #[::fuchsia::test]
    fn test_multiple_thread() {
        const THREADS_COUNT: u64 = 10;
        const INC_ITERATIONS: u64 = 1000;
        let counter = Arc::new(AtomicU64Counter::default());

        // Collect eagerly so that all threads are spawned before any of them is joined.
        let thread_handles: Vec<_> = (0..THREADS_COUNT)
            .map(|_| {
                let counter = Arc::clone(&counter);
                std::thread::spawn(move || {
                    for _ in 0..INC_ITERATIONS {
                        counter.next();
                    }
                })
            })
            .collect();
        for handle in thread_handles {
            handle.join().expect("join");
        }
        assert_eq!(THREADS_COUNT * INC_ITERATIONS, counter.get());
    }
}