Skip to main content

atomic_bitflags/
lib.rs

1// Copyright 2026 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5pub use bitflags as __bitflags;
6pub use paste;
7
/// Declares a `bitflags` type together with an atomic wrapper for it.
///
/// The input is the same single-struct syntax accepted by `bitflags::bitflags!`. Besides the
/// flags type `$BitFlags` itself, the macro generates `Atomic$BitFlags`. That type stores the
/// raw bits in the `core::sync::atomic` integer type whose name matches the bits type (`u32`
/// becomes `AtomicU32`, `u8` becomes `AtomicU8`, and so on) and exposes typed versions of the
/// common atomic operations. The bits type must therefore be a primitive integer that has an
/// atomic counterpart on the target.
///
/// Bits read back from the atomic are converted with `from_bits_retain`. Every bit that was
/// stored, including bits that do not correspond to a named flag, is therefore reported
/// faithfully. This keeps `store`/`load` round trips lossless and guarantees that the value
/// returned by a failed `compare_exchange` can be fed straight back in as `current` when
/// retrying.
///
/// All paths in the expansion are fully qualified (`::core::...`). The generated code therefore
/// compiles even when the calling module shadows prelude names such as `Result`, and it also
/// compiles in `no_std` crates.
///
/// # Example
///
/// ```ignore
/// use std::sync::atomic::Ordering;
///
/// atomic_bitflags! {
///     #[derive(Debug, Clone, Copy, PartialEq, Eq)]
///     pub struct Flags: u32 {
///         const A = 1 << 0;
///         const B = 1 << 1;
///     }
/// }
///
/// static FLAGS: AtomicFlags = AtomicFlags::new(Flags::empty());
///
/// FLAGS.fetch_or(Flags::A, Ordering::SeqCst);
/// assert_eq!(FLAGS.load(Ordering::SeqCst), Flags::A);
/// ```
#[macro_export]
macro_rules! atomic_bitflags {
    (
        $(#[$outer:meta])*
        $vis:vis struct $BitFlags:ident: $T:ty {
            $($t:tt)*
        }
    ) => {
        $crate::paste::paste! {
            $crate::__bitflags::bitflags! {
                $(#[$outer])*
                $vis struct $BitFlags: $T {
                    $($t)*
                }
            }

            #[doc = "An atomic cell holding a [`" $BitFlags "`] value."]
            ///
            /// Generated by `atomic_bitflags!`. The `Default` value has no bits set.
            // Not every invocation uses every generated item, so dead-code lints are silenced.
            #[allow(dead_code)]
            #[derive(::core::fmt::Debug, ::core::default::Default)]
            $vis struct [<Atomic $BitFlags>] {
                inner: ::core::sync::atomic::[<Atomic $T:camel>],
            }

            // Not every invocation uses every generated method, so dead-code lints are silenced.
            #[allow(dead_code)]
            impl [<Atomic $BitFlags>] {
                /// Creates a new atomic holding `initial`.
                ///
                /// This is a `const fn`, so it can be used to initialize `static` items.
                pub const fn new(initial: $BitFlags) -> Self {
                    Self {
                        inner: ::core::sync::atomic::[<Atomic $T:camel>]::new(initial.bits()),
                    }
                }

                /// Loads the current flags.
                ///
                /// # Panics
                ///
                /// Panics if `order` is `Release` or `AcqRel`.
                pub fn load(&self, order: ::core::sync::atomic::Ordering) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.load(order))
                }

                /// Replaces the current flags with `val`.
                ///
                /// # Panics
                ///
                /// Panics if `order` is `Acquire` or `AcqRel`.
                pub fn store(&self, val: $BitFlags, order: ::core::sync::atomic::Ordering) {
                    self.inner.store(val.bits(), order);
                }

                /// Sets every flag in `val` (bitwise OR) and returns the previous flags.
                pub fn fetch_or(
                    &self,
                    val: $BitFlags,
                    order: ::core::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.fetch_or(val.bits(), order))
                }

                /// Keeps only the flags also present in `val` (bitwise AND) and returns the
                /// previous flags.
                pub fn fetch_and(
                    &self,
                    val: $BitFlags,
                    order: ::core::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.fetch_and(val.bits(), order))
                }

                /// Replaces the current flags with `val` and returns the previous flags.
                pub fn swap(
                    &self,
                    val: $BitFlags,
                    order: ::core::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.swap(val.bits(), order))
                }

                /// Stores `new` if the current flags are exactly `current`.
                ///
                /// Returns `Ok(previous)` on success and `Err(actual)` on failure. In both cases
                /// the value is the one observed in the atomic, with no bits dropped, so an `Err`
                /// value can be used directly as the next `current` in a retry loop.
                ///
                /// # Panics
                ///
                /// Panics if `failure` is `Release` or `AcqRel`.
                pub fn compare_exchange(
                    &self,
                    current: $BitFlags,
                    new: $BitFlags,
                    success: ::core::sync::atomic::Ordering,
                    failure: ::core::sync::atomic::Ordering,
                ) -> ::core::result::Result<$BitFlags, $BitFlags> {
                    self.inner
                        .compare_exchange(current.bits(), new.bits(), success, failure)
                        .map($BitFlags::from_bits_retain)
                        .map_err($BitFlags::from_bits_retain)
                }

                /// Overwrites the flags selected by `mask` with the corresponding flags from
                /// `value` and returns the previous flags.
                ///
                /// Bits outside `mask` are left untouched, and flags in `value` outside `mask`
                /// are ignored. The read-modify-write is retried until it applies atomically.
                /// `set_order` is used for the successful store and `fetch_order` for the loads.
                ///
                /// # Panics
                ///
                /// Panics if `fetch_order` is `Release` or `AcqRel`.
                pub fn update(
                    &self,
                    value: $BitFlags,
                    mask: $BitFlags,
                    set_order: ::core::sync::atomic::Ordering,
                    fetch_order: ::core::sync::atomic::Ordering,
                ) -> $BitFlags {
                    let mask_bits = mask.bits();
                    let value_bits = value.bits() & mask_bits;
                    let result = self.inner.fetch_update(set_order, fetch_order, |old| {
                        ::core::option::Option::Some((old & !mask_bits) | value_bits)
                    });
                    // The closure never returns `None`, so `result` is always `Ok`. Accepting
                    // both variants avoids an `unwrap` that can never fire.
                    let prev = match result {
                        ::core::result::Result::Ok(prev) | ::core::result::Result::Err(prev) => prev,
                    };
                    $BitFlags::from_bits_retain(prev)
                }
            }

            impl ::core::convert::From<$BitFlags> for [<Atomic $BitFlags>] {
                /// Wraps `initial` in a new atomic; equivalent to calling `new`.
                fn from(initial: $BitFlags) -> Self {
                    Self::new(initial)
                }
            }
        }
    };
}
91
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    atomic_bitflags! {
        #[derive(PartialEq, Eq, Debug, Clone, Copy)]
        pub struct TestFlags: u32 {
            const A = 1 << 0;
            const B = 1 << 1;
            const C = 1 << 2;
        }
    }

    #[test]
    fn test_atomic_bitflags() {
        let atomic = AtomicTestFlags::new(TestFlags::A);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A);

        atomic.store(TestFlags::B, Ordering::Relaxed);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B);

        let prev = atomic.fetch_or(TestFlags::C, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B | TestFlags::C);

        let prev = atomic.fetch_and(TestFlags::C, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::B | TestFlags::C);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::C);
    }

    #[test]
    fn test_swap() {
        let atomic = AtomicTestFlags::new(TestFlags::A | TestFlags::B);

        let prev = atomic.swap(TestFlags::C, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::A | TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::C);
    }

    #[test]
    fn test_compare_exchange() {
        let atomic = AtomicTestFlags::new(TestFlags::A);

        // Matching `current`: the exchange happens and the previous value is returned.
        let result = atomic.compare_exchange(
            TestFlags::A,
            TestFlags::B,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
        assert_eq!(result, Ok(TestFlags::A));
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B);

        // Stale `current`: nothing is stored and the actual value is reported.
        let result = atomic.compare_exchange(
            TestFlags::A,
            TestFlags::C,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
        assert_eq!(result, Err(TestFlags::B));
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B);
    }

    #[test]
    fn test_update() {
        let atomic = AtomicTestFlags::new(TestFlags::A | TestFlags::B);

        // Update A to 0, leaving B as is. Mask is A. Value is 0.
        let prev =
            atomic.update(TestFlags::empty(), TestFlags::A, Ordering::Relaxed, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::A | TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B);

        // Update A to 1, leaving B as is. Mask is A. Value is A.
        let prev = atomic.update(TestFlags::A, TestFlags::A, Ordering::Relaxed, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A | TestFlags::B);

        // Update B to 0, A to 0. Mask is A | B. Value is 0.
        let prev = atomic.update(
            TestFlags::empty(),
            TestFlags::A | TestFlags::B,
            Ordering::Relaxed,
            Ordering::Relaxed,
        );
        assert_eq!(prev, TestFlags::A | TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::empty());

        // Flags in `value` outside `mask` are ignored.
        let prev = atomic.update(
            TestFlags::A | TestFlags::C,
            TestFlags::A,
            Ordering::Relaxed,
            Ordering::Relaxed,
        );
        assert_eq!(prev, TestFlags::empty());
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A);
    }

    #[test]
    fn test_from() {
        let atomic: AtomicTestFlags = TestFlags::A.into();
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A);
    }

    #[test]
    fn test_default_is_empty() {
        let atomic = AtomicTestFlags::default();
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::empty());
    }

    #[test]
    fn test_concurrent_fetch_or() {
        let atomic = AtomicTestFlags::default();

        // Each thread sets one distinct flag; no update may be lost.
        std::thread::scope(|scope| {
            for flag in [TestFlags::A, TestFlags::B, TestFlags::C] {
                let atomic = &atomic;
                scope.spawn(move || {
                    atomic.fetch_or(flag, Ordering::Relaxed);
                });
            }
        });

        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::all());
    }
}