starnix_lifecycle/
atomic_counter.rs1use starnix_types::atomic::{AsAtomic, AtomicOperations};
8use std::sync::atomic::Ordering;
9
10#[derive(Debug)]
12pub struct AtomicCounter<T: AsAtomic>(T::Atomic);
13
14impl<T: AsAtomic> AtomicCounter<T> {
15 pub fn new(value: T) -> Self {
16 Self(T::Atomic::new(value))
17 }
18
19 pub fn next(&self) -> T {
20 self.add(T::ONE)
21 }
22
23 pub fn add(&self, amount: T) -> T {
24 self.0.fetch_add(amount, Ordering::Relaxed)
25 }
26
27 pub fn get(&self) -> T {
28 self.0.load(Ordering::Relaxed)
29 }
30
31 pub fn reset(&mut self, value: T) {
32 self.0.store(value, Ordering::Relaxed);
33 }
34}
35
36impl AtomicCounter<u32> {
37 pub const fn new_const(value: u32) -> Self {
38 Self(std::sync::atomic::AtomicU32::new(value))
39 }
40}
41
42impl AtomicCounter<usize> {
43 pub const fn new_const(value: usize) -> Self {
44 Self(std::sync::atomic::AtomicUsize::new(value))
45 }
46}
47
48impl<T: AsAtomic> Default for AtomicCounter<T>
49where
50 T: Default,
51{
52 fn default() -> Self {
53 Self::new(T::default())
54 }
55}
56
57impl<T: AsAtomic> From<T> for AtomicCounter<T> {
58 fn from(value: T) -> Self {
59 Self::new(value)
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    #[::fuchsia::test]
    fn test_new() {
        let counter = AtomicCounter::new(0u64);
        assert_eq!(counter.get(), 0);
    }

    #[::fuchsia::test]
    fn test_one_thread() {
        // `reset` needs exclusive access, hence the `mut` binding.
        let mut counter = AtomicCounter::<u64>::default();
        assert_eq!(counter.get(), 0);

        // Both `add` and `next` report the value observed *before* the update.
        assert_eq!(counter.add(5), 0);
        assert_eq!(counter.get(), 5);
        assert_eq!(counter.next(), 5);
        assert_eq!(counter.get(), 6);

        counter.reset(2);
        assert_eq!(counter.get(), 2);
        assert_eq!(counter.next(), 2);
        assert_eq!(counter.get(), 3);
    }

    #[::fuchsia::test]
    fn test_multiple_thread() {
        const THREADS_COUNT: u64 = 10;
        const INC_ITERATIONS: u64 = 1000;
        let counter = AtomicCounter::<u64>::default();

        // Scoped threads borrow `counter` directly. The scope joins every worker before
        // returning, and it propagates a worker panic.
        std::thread::scope(|scope| {
            for _ in 0..THREADS_COUNT {
                scope.spawn(|| {
                    for _ in 0..INC_ITERATIONS {
                        counter.next();
                    }
                });
            }
        });

        assert_eq!(counter.get(), THREADS_COUNT * INC_ITERATIONS);
    }
}