Skip to main content

starnix_lifecycle/
atomic_counter.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Helper type implementing a counter that can be shared across threads.
6
7use starnix_types::atomic::{AsAtomic, AtomicOperations};
8use std::sync::atomic::Ordering;
9
10/// A generic atomic counter.
11#[derive(Debug)]
12pub struct AtomicCounter<T: AsAtomic>(T::Atomic);
13
14impl<T: AsAtomic> AtomicCounter<T> {
15    pub fn new(value: T) -> Self {
16        Self(T::Atomic::new(value))
17    }
18
19    pub fn next(&self) -> T {
20        self.add(T::ONE)
21    }
22
23    pub fn add(&self, amount: T) -> T {
24        self.0.fetch_add(amount, Ordering::Relaxed)
25    }
26
27    pub fn get(&self) -> T {
28        self.0.load(Ordering::Relaxed)
29    }
30
31    pub fn reset(&mut self, value: T) {
32        self.0.store(value, Ordering::Relaxed);
33    }
34}
35
36impl AtomicCounter<u32> {
37    pub const fn new_const(value: u32) -> Self {
38        Self(std::sync::atomic::AtomicU32::new(value))
39    }
40}
41
42impl AtomicCounter<usize> {
43    pub const fn new_const(value: usize) -> Self {
44        Self(std::sync::atomic::AtomicUsize::new(value))
45    }
46}
47
48impl<T: AsAtomic> Default for AtomicCounter<T>
49where
50    T: Default,
51{
52    fn default() -> Self {
53        Self::new(T::default())
54    }
55}
56
57impl<T: AsAtomic> From<T> for AtomicCounter<T> {
58    fn from(value: T) -> Self {
59        Self::new(value)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    #[::fuchsia::test]
    fn test_new() {
        let counter = AtomicCounter::<u64>::new(0);
        assert_eq!(counter.get(), 0);
    }

    #[::fuchsia::test]
    fn test_one_thread() {
        let mut counter = AtomicCounter::<u64>::default();
        assert_eq!(counter.get(), 0);
        assert_eq!(counter.add(5), 0);
        assert_eq!(counter.get(), 5);
        assert_eq!(counter.next(), 5);
        assert_eq!(counter.get(), 6);
        counter.reset(2);
        assert_eq!(counter.get(), 2);
        assert_eq!(counter.next(), 2);
        assert_eq!(counter.get(), 3);
    }

    #[::fuchsia::test]
    fn test_new_const() {
        // Using `static` forces `new_const` to be evaluated at compile time. A `const` item
        // would trip clippy's `declare_interior_mutable_const` lint.
        static COUNTER_U32: AtomicCounter<u32> = AtomicCounter::<u32>::new_const(7);
        static COUNTER_USIZE: AtomicCounter<usize> = AtomicCounter::<usize>::new_const(40);

        assert_eq!(COUNTER_U32.next(), 7);
        assert_eq!(COUNTER_U32.get(), 8);
        assert_eq!(COUNTER_USIZE.add(2), 40);
        assert_eq!(COUNTER_USIZE.get(), 42);
    }

    #[::fuchsia::test]
    fn test_from() {
        let counter = AtomicCounter::from(7u64);
        assert_eq!(counter.get(), 7);

        let counter: AtomicCounter<u64> = 3u64.into();
        assert_eq!(counter.next(), 3);
        assert_eq!(counter.get(), 4);
    }

    #[::fuchsia::test]
    fn test_multiple_thread() {
        const THREADS_COUNT: u64 = 10;
        const INC_ITERATIONS: u64 = 1000;
        let counter = AtomicCounter::<u64>::default();

        // Scoped threads borrow `counter` directly, so no `Arc` is needed. The scope joins
        // every thread before returning and re-raises any thread panic, failing the test.
        std::thread::scope(|scope| {
            for _ in 0..THREADS_COUNT {
                scope.spawn(|| {
                    for _ in 0..INC_ITERATIONS {
                        counter.next();
                    }
                });
            }
        });
        assert_eq!(THREADS_COUNT * INC_ITERATIONS, counter.get());
    }
}