1pub use bitflags as __bitflags;
6pub use paste;
7
/// Declares a `bitflags` type together with a lock-free atomic wrapper for it.
///
/// The input syntax is that of a single struct in `bitflags::bitflags!`:
/// optional outer attributes, a visibility, `struct Name: BitsType { ... }`.
/// The macro emits:
///
/// * the flags struct `Name`, generated by the re-exported `bitflags!`, and
/// * `AtomicName`, which stores the flags' raw bits in the standard-library
///   atomic integer matching `BitsType` (e.g. `u32` -> `AtomicU32`).
///
/// Each operation on `AtomicName` forwards to the method of the same purpose
/// on the underlying atomic integer, with the same `Ordering` arguments.
/// Values read back are rebuilt with `from_bits_retain`, so exactly the bits
/// that were written are returned.
#[macro_export]
macro_rules! atomic_bitflags {
    (
        $(#[$outer:meta])*
        $vis:vis struct $BitFlags:ident: $T:ty {
            $($t:tt)*
        }
    ) => {
        $crate::paste::paste! {
            $crate::__bitflags::bitflags! {
                $(#[$outer])*
                $vis struct $BitFlags: $T {
                    $($t)*
                }
            }

            /// Atomic cell holding a flags value.
            ///
            /// All operations go through a standard-library atomic integer.
            /// They panic under the same invalid-`Ordering` conditions as the
            /// underlying `std` methods. `Default` produces an empty flag set.
            // A given invocation may never use the generated wrapper.
            #[allow(dead_code)]
            #[derive(Debug, Default)]
            $vis struct [<Atomic $BitFlags>] {
                inner: ::std::sync::atomic::[<Atomic $T:camel>],
            }

            // A given invocation may use only some of these methods.
            #[allow(dead_code)]
            impl [<Atomic $BitFlags>] {
                /// Creates a new atomic cell initialised to `initial`.
                ///
                /// This is a `const fn`, so it can initialise a `static`.
                pub const fn new(initial: $BitFlags) -> Self {
                    Self {
                        inner: ::std::sync::atomic::[<Atomic $T:camel>]::new(initial.bits()),
                    }
                }

                /// Loads the current value.
                pub fn load(&self, order: ::std::sync::atomic::Ordering) -> $BitFlags {
                    // `from_bits_retain` (not `from_bits_truncate`) keeps the
                    // round-trip exact. `store(x); load()` yields `x` even when
                    // `x` carries bits outside the declared constants. A
                    // compare-exchange loop fed from `load` or from a failed
                    // exchange therefore sees the real stored bits.
                    $BitFlags::from_bits_retain(self.inner.load(order))
                }

                /// Stores `val`, replacing the current value.
                pub fn store(&self, val: $BitFlags, order: ::std::sync::atomic::Ordering) {
                    self.inner.store(val.bits(), order);
                }

                /// Sets the flags in `val` (bitwise OR) and returns the previous value.
                pub fn fetch_or(
                    &self,
                    val: $BitFlags,
                    order: ::std::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.fetch_or(val.bits(), order))
                }

                /// Keeps only the flags also present in `val` (bitwise AND) and
                /// returns the previous value.
                pub fn fetch_and(
                    &self,
                    val: $BitFlags,
                    order: ::std::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.fetch_and(val.bits(), order))
                }

                /// Toggles the flags in `val` (bitwise XOR) and returns the previous value.
                pub fn fetch_xor(
                    &self,
                    val: $BitFlags,
                    order: ::std::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.fetch_xor(val.bits(), order))
                }

                /// Replaces the value with `val` and returns the previous value.
                pub fn swap(
                    &self,
                    val: $BitFlags,
                    order: ::std::sync::atomic::Ordering,
                ) -> $BitFlags {
                    $BitFlags::from_bits_retain(self.inner.swap(val.bits(), order))
                }

                /// Stores `new` if the current value is exactly `current`.
                ///
                /// Returns `Ok(previous)` on success. On failure it returns
                /// `Err(actual)`, where `actual` is the value found in the cell.
                pub fn compare_exchange(
                    &self,
                    current: $BitFlags,
                    new: $BitFlags,
                    success: ::std::sync::atomic::Ordering,
                    failure: ::std::sync::atomic::Ordering,
                ) -> ::std::result::Result<$BitFlags, $BitFlags> {
                    self.inner
                        .compare_exchange(current.bits(), new.bits(), success, failure)
                        .map($BitFlags::from_bits_retain)
                        .map_err($BitFlags::from_bits_retain)
                }

                /// Overwrites the bits selected by `mask` with the corresponding
                /// bits of `value`.
                ///
                /// All other bits are left unchanged, and bits of `value` outside
                /// `mask` are ignored. Returns the previous value. `set_order` and
                /// `fetch_order` are passed straight to the integer atomic's
                /// `fetch_update`.
                pub fn update(
                    &self,
                    value: $BitFlags,
                    mask: $BitFlags,
                    set_order: ::std::sync::atomic::Ordering,
                    fetch_order: ::std::sync::atomic::Ordering,
                ) -> $BitFlags {
                    let prev = self
                        .inner
                        .fetch_update(set_order, fetch_order, |old| {
                            ::std::option::Option::Some(
                                (old & !mask.bits()) | (value.bits() & mask.bits()),
                            )
                        })
                        // The closure never returns `None`, so `fetch_update`
                        // always succeeds. Both arms carry the previous bits,
                        // which avoids a panic path entirely.
                        .unwrap_or_else(|prev| prev);
                    $BitFlags::from_bits_retain(prev)
                }
            }

            impl ::std::convert::From<$BitFlags> for [<Atomic $BitFlags>] {
                /// Wraps `initial` in a new atomic cell (same as `new`).
                fn from(initial: $BitFlags) -> Self {
                    Self::new(initial)
                }
            }
        }
    };
}
91
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    atomic_bitflags! {
        #[derive(PartialEq, Eq, Debug, Clone, Copy)]
        pub struct TestFlags: u32 {
            const A = 1 << 0;
            const B = 1 << 1;
            const C = 1 << 2;
        }
    }

    /// Exercises load/store and the read-modify-write bit operations.
    #[test]
    fn test_atomic_bitflags() {
        let atomic = AtomicTestFlags::new(TestFlags::A);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A);

        atomic.store(TestFlags::B, Ordering::Relaxed);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B);

        let prev = atomic.fetch_or(TestFlags::C, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B | TestFlags::C);

        let prev = atomic.fetch_and(TestFlags::C, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::B | TestFlags::C);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::C);
    }

    /// `update` clears and sets masked bits and returns the previous value.
    #[test]
    fn test_update() {
        let atomic = AtomicTestFlags::new(TestFlags::A | TestFlags::B);

        let prev =
            atomic.update(TestFlags::empty(), TestFlags::A, Ordering::Relaxed, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::A | TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::B);

        let prev = atomic.update(TestFlags::A, TestFlags::A, Ordering::Relaxed, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A | TestFlags::B);

        let prev = atomic.update(
            TestFlags::empty(),
            TestFlags::A | TestFlags::B,
            Ordering::Relaxed,
            Ordering::Relaxed,
        );
        assert_eq!(prev, TestFlags::A | TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::empty());
    }

    /// Bits of `value` that fall outside `mask` must not be written.
    #[test]
    fn test_update_ignores_bits_outside_mask() {
        let atomic = AtomicTestFlags::new(TestFlags::empty());

        let prev = atomic.update(
            TestFlags::A | TestFlags::C,
            TestFlags::A,
            Ordering::Relaxed,
            Ordering::Relaxed,
        );
        assert_eq!(prev, TestFlags::empty());
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A);
    }

    /// `swap` installs the new value and hands back the old one.
    #[test]
    fn test_swap() {
        let atomic = AtomicTestFlags::new(TestFlags::A | TestFlags::B);

        let prev = atomic.swap(TestFlags::C, Ordering::Relaxed);
        assert_eq!(prev, TestFlags::A | TestFlags::B);
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::C);
    }

    /// `compare_exchange` succeeds only on an exact match and reports the
    /// actual value on failure.
    #[test]
    fn test_compare_exchange() {
        let atomic = AtomicTestFlags::new(TestFlags::A);

        let res =
            atomic.compare_exchange(TestFlags::A, TestFlags::B, Ordering::SeqCst, Ordering::SeqCst);
        assert_eq!(res, Ok(TestFlags::A));
        assert_eq!(atomic.load(Ordering::SeqCst), TestFlags::B);

        let res =
            atomic.compare_exchange(TestFlags::A, TestFlags::C, Ordering::SeqCst, Ordering::SeqCst);
        assert_eq!(res, Err(TestFlags::B));
        assert_eq!(atomic.load(Ordering::SeqCst), TestFlags::B);
    }

    /// The derived `Default` starts from an empty flag set.
    #[test]
    fn test_default_is_empty() {
        let atomic = AtomicTestFlags::default();
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::empty());
    }

    /// `From` conversion is equivalent to `new`.
    #[test]
    fn test_from() {
        let atomic: AtomicTestFlags = TestFlags::A.into();
        assert_eq!(atomic.load(Ordering::Relaxed), TestFlags::A);
    }
}