// Defines `Atomic<Ty>Counter` (for example `AtomicU64Counter` from `u64`): a newtype over the
// matching `std::sync::atomic::Atomic<Ty>` that exposes a small counter API. Every atomic access
// uses `Ordering::Relaxed`, so only the counter value itself is updated atomically; no ordering
// is imposed on surrounding memory operations.
//
// Identifier concatenation (`[< ... >]`) and the `:camel` case conversion come from `paste`.
macro_rules! atomic_counter_definition {
    ($ty:ty) => {
        paste::paste! {
            /// Thread-safe counter wrapping a standard-library atomic integer.
            ///
            /// All shared-reference operations use relaxed memory ordering.
            #[derive(Debug, Default)]
            pub struct [< Atomic $ty:camel Counter >](::std::sync::atomic::[< Atomic $ty:camel >]);

            // NOTE(review): presumably allowed because not every generated counter type uses
            // every method — confirm before removing.
            #[allow(dead_code)]
            impl [< Atomic $ty:camel Counter >] {
                /// Creates a counter whose current value is `value`.
                ///
                /// This is a `const fn`, so it can initialize `static` items.
                pub const fn new(value: $ty) -> Self {
                    Self(::std::sync::atomic::[< Atomic $ty:camel >]::new(value))
                }

                /// Returns the current value.
                pub fn get(&self) -> $ty {
                    self.0.load(::std::sync::atomic::Ordering::Relaxed)
                }

                /// Adds `amount` and returns the value held *before* the addition.
                ///
                /// Like the underlying `fetch_add`, this wraps around on overflow.
                pub fn add(&self, amount: $ty) -> $ty {
                    self.0.fetch_add(amount, ::std::sync::atomic::Ordering::Relaxed)
                }

                /// Increments by one and returns the value held *before* the increment.
                pub fn next(&self) -> $ty {
                    self.0.fetch_add(1, ::std::sync::atomic::Ordering::Relaxed)
                }

                /// Overwrites the current value with `value`.
                ///
                /// Taking `&mut self` guarantees exclusive access, so a plain replacement of the
                /// inner atomic suffices.
                pub fn reset(&mut self, value: $ty) {
                    self.0 = ::std::sync::atomic::[< Atomic $ty:camel >]::new(value);
                }
            }

            impl From<$ty> for [< Atomic $ty:camel Counter >] {
                /// Builds a counter starting at `value`; equivalent to `new(value)`.
                fn from(value: $ty) -> Self {
                    Self(::std::sync::atomic::[< Atomic $ty:camel >]::new(value))
                }
            }
        }
    };
}

46atomic_counter_definition!(u64);
47atomic_counter_definition!(u32);
48atomic_counter_definition!(usize);
49
// Unit tests for the generated counter types.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[::fuchsia::test]
    fn test_new() {
        // `static` rather than `const`: this still proves `new` is usable in a const context, but
        // avoids declaring an interior-mutable `const` (clippy `declare_interior_mutable_const`),
        // which would silently create a fresh counter at every use site.
        static COUNTER: AtomicU64Counter = AtomicU64Counter::new(0);
        assert_eq!(COUNTER.get(), 0);

        static COUNTER_FROM_FIVE: AtomicU64Counter = AtomicU64Counter::new(5);
        assert_eq!(COUNTER_FROM_FIVE.get(), 5);
    }

    #[::fuchsia::test]
    fn test_one_thread() {
        let mut counter = AtomicU64Counter::default();
        assert_eq!(counter.get(), 0);
        assert_eq!(counter.add(5), 0);
        assert_eq!(counter.get(), 5);
        assert_eq!(counter.next(), 5);
        assert_eq!(counter.get(), 6);
        counter.reset(2);
        assert_eq!(counter.get(), 2);
        assert_eq!(counter.next(), 2);
        assert_eq!(counter.get(), 3);
    }

    // Covers the `From` impl and the non-u64 instantiations, which were otherwise untested.
    #[::fuchsia::test]
    fn test_from_and_other_widths() {
        let counter = AtomicU32Counter::from(7);
        assert_eq!(counter.next(), 7);
        assert_eq!(counter.get(), 8);

        let counter: AtomicUsizeCounter = 3usize.into();
        assert_eq!(counter.add(2), 3);
        assert_eq!(counter.get(), 5);
    }

    #[::fuchsia::test]
    fn test_multiple_thread() {
        const THREADS_COUNT: u64 = 10;
        const INC_ITERATIONS: u64 = 1000;
        let counter = Arc::new(AtomicU64Counter::default());

        let thread_handles: Vec<_> = (0..THREADS_COUNT)
            .map(|_| {
                let counter = Arc::clone(&counter);
                std::thread::spawn(move || {
                    for _ in 0..INC_ITERATIONS {
                        counter.next();
                    }
                })
            })
            .collect();
        for handle in thread_handles {
            handle.join().expect("join");
        }
        // No increment may be lost even with relaxed ordering: each fetch_add is atomic.
        assert_eq!(THREADS_COUNT * INC_ITERATIONS, counter.get());
    }
}