netstack3_sync/
lib.rs
//! Synchronization primitives used by Netstack3.

#![warn(missing_docs, unreachable_patterns, unused)]

extern crate alloc;

#[cfg(loom)]
pub(crate) use loom::sync;
#[cfg(not(loom))]
pub(crate) use std::sync;

use net_types::ip::{GenericOverIp, Ip};

pub use sync::atomic;

pub mod rc;
/// A mutual exclusion lock wrapping `sync::Mutex` that panics instead of
/// surfacing lock poisoning.
#[derive(Debug, Default)]
pub struct Mutex<T>(sync::Mutex<T>);

/// Guard produced by `Mutex::lock`.
pub type LockGuard<'a, T> = lock_guard::LockGuard<'a, Mutex<T>, sync::MutexGuard<'a, T>>;

impl<T> Mutex<T> {
    /// Creates a new `Mutex` holding `t`.
    pub fn new(t: T) -> Mutex<T> {
        Mutex(sync::Mutex::new(t))
    }

    /// Acquires the lock, blocking the current thread until it is available.
    ///
    /// Panics if the lock is poisoned or, with the "recursive-lock-panic"
    /// feature, if this thread already holds it.
    #[inline]
    #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
    pub fn lock(&self) -> LockGuard<'_, T> {
        lock_guard::LockGuard::new(self, |Self(m)| m.lock().expect("unexpectedly poisoned"))
    }

    /// Consumes the mutex and returns the wrapped value.
    ///
    /// Panics if the lock is poisoned.
    #[inline]
    pub fn into_inner(self) -> T {
        let Self(mutex) = self;
        mutex.into_inner().expect("unexpectedly poisoned")
    }

    /// Returns a mutable reference to the wrapped value without locking.
    ///
    /// Panics if the lock is poisoned.
    #[inline]
    #[cfg(not(loom))]
    pub fn get_mut(&mut self) -> &mut T {
        self.0.get_mut().expect("unexpectedly poisoned")
    }
}
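
// Illustrative usage sketch, not part of the original file: the guard returned
// by `lock` releases the mutex (and, under the "recursive-lock-panic" feature,
// clears the tracking entry) when dropped, so the same thread can lock again
// afterwards. The module and test names below are hypothetical.
#[cfg(test)]
mod mutex_usage_sketch {
    #[test]
    fn relock_after_guard_drop() {
        let m = crate::Mutex::new(0u32);
        {
            let mut guard = m.lock();
            *guard += 1;
        } // `guard` is dropped here and the mutex is unlocked.
        // Locking again on the same thread now succeeds.
        assert_eq!(*m.lock(), 1);
        assert_eq!(m.into_inner(), 1);
    }
}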

impl<T: 'static> lock_order::lock::ExclusiveLock<T> for Mutex<T> {
    type Guard<'l> = LockGuard<'l, T>;

    fn lock(&self) -> Self::Guard<'_> {
        self.lock()
    }
}

impl<T, I: Ip> GenericOverIp<I> for Mutex<T>
where
    T: GenericOverIp<I>,
{
    type Type = Mutex<T::Type>;
}

/// A reader-writer lock wrapping `sync::RwLock` that panics instead of
/// surfacing lock poisoning.
#[derive(Debug, Default)]
pub struct RwLock<T>(sync::RwLock<T>);

/// Read guard produced by `RwLock::read`.
pub type RwLockReadGuard<'a, T> =
    lock_guard::LockGuard<'a, RwLock<T>, sync::RwLockReadGuard<'a, T>>;

/// Write guard produced by `RwLock::write`.
pub type RwLockWriteGuard<'a, T> =
    lock_guard::LockGuard<'a, RwLock<T>, sync::RwLockWriteGuard<'a, T>>;

impl<T> RwLock<T> {
    /// Creates a new `RwLock` holding `t`.
    pub fn new(t: T) -> RwLock<T> {
        RwLock(sync::RwLock::new(t))
    }

    /// Acquires shared read access, blocking until any writer releases the lock.
    ///
    /// Panics if the lock is poisoned or, with the "recursive-lock-panic"
    /// feature, if this thread already holds it.
    #[inline]
    #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        lock_guard::LockGuard::new(self, |Self(rw)| rw.read().expect("unexpectedly poisoned"))
    }

    /// Acquires exclusive write access, blocking until all readers and writers
    /// release the lock.
    ///
    /// Panics if the lock is poisoned or, with the "recursive-lock-panic"
    /// feature, if this thread already holds it.
    #[inline]
    #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        lock_guard::LockGuard::new(self, |Self(rw)| rw.write().expect("unexpectedly poisoned"))
    }

    /// Consumes the lock and returns the wrapped value.
    ///
    /// Panics if the lock is poisoned.
    #[inline]
    pub fn into_inner(self) -> T {
        let Self(rwlock) = self;
        rwlock.into_inner().expect("unexpectedly poisoned")
    }

    /// Returns a mutable reference to the wrapped value without locking.
    ///
    /// Panics if the lock is poisoned.
    #[inline]
    #[cfg(not(loom))]
    pub fn get_mut(&mut self) -> &mut T {
        self.0.get_mut().expect("unexpectedly poisoned")
    }
}
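
// Illustrative usage sketch, not part of the original file: when the caller
// has exclusive access (`&mut RwLock<T>`), `get_mut` reaches the value without
// taking the lock at all. The module and test names below are hypothetical.
#[cfg(all(test, not(loom)))]
mod rwlock_usage_sketch {
    #[test]
    fn exclusive_access_skips_locking() {
        let mut rw = crate::RwLock::new(0u32);
        // No guard is needed: `&mut self` already proves unique access.
        *rw.get_mut() = 7;
        assert_eq!(*rw.read(), 7);
        assert_eq!(rw.into_inner(), 7);
    }
}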

impl<T: 'static> lock_order::lock::ReadWriteLock<T> for RwLock<T> {
    type ReadGuard<'l> = RwLockReadGuard<'l, T>;

    type WriteGuard<'l> = RwLockWriteGuard<'l, T>;

    fn read_lock(&self) -> Self::ReadGuard<'_> {
        self.read()
    }

    fn write_lock(&self) -> Self::WriteGuard<'_> {
        self.write()
    }
}

impl<T, I: Ip> GenericOverIp<I> for RwLock<T>
where
    T: GenericOverIp<I>,
{
    type Type = RwLock<T::Type>;
}

mod lock_guard {
    #[cfg(not(feature = "recursive-lock-panic"))]
    use core::marker::PhantomData;
    use core::ops::{Deref, DerefMut};

    #[cfg(feature = "recursive-lock-panic")]
    use crate::lock_tracker::LockTracker;

    /// A wrapper around a guard of type `G` for a lock of type `L`.
    ///
    /// Carries the lock-tracking state when the "recursive-lock-panic" feature
    /// is enabled.
    pub struct LockGuard<'a, L, G> {
        guard: G,

        #[cfg(feature = "recursive-lock-panic")]
        _lock_tracker: LockTracker<'a, L>,
        #[cfg(not(feature = "recursive-lock-panic"))]
        _marker: PhantomData<&'a L>,
    }

    impl<'a, L, G> LockGuard<'a, L, G> {
        /// Acquires the lock by calling `lock_fn` and wraps the resulting guard.
        #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
        pub fn new<F: FnOnce(&'a L) -> G>(lock: &'a L, lock_fn: F) -> Self {
            // The tracker is constructed before the lock is taken so that a
            // recursive acquisition panics instead of deadlocking.
            #[cfg(feature = "recursive-lock-panic")]
            let lock_tracker = LockTracker::new(lock);

            Self {
                guard: lock_fn(lock),

                #[cfg(feature = "recursive-lock-panic")]
                _lock_tracker: lock_tracker,
                #[cfg(not(feature = "recursive-lock-panic"))]
                _marker: PhantomData,
            }
        }
    }

    impl<L, G: Deref> Deref for LockGuard<'_, L, G> {
        type Target = G::Target;

        fn deref(&self) -> &G::Target {
            self.guard.deref()
        }
    }

    impl<L, G: DerefMut> DerefMut for LockGuard<'_, L, G> {
        fn deref_mut(&mut self) -> &mut G::Target {
            self.guard.deref_mut()
        }
    }
}

#[cfg(feature = "recursive-lock-panic")]
mod lock_tracker {
    use core::cell::RefCell;
    use core::panic::Location;
    use std::collections::HashMap;

    std::thread_local! {
        /// Locks currently held by this thread, keyed by the lock's address and
        /// mapping to the location that acquired it.
        static HELD_LOCKS: RefCell<HashMap<*const usize, &'static Location<'static>>> =
            RefCell::new(HashMap::new());
    }

    /// Tracks that the current thread holds a lock for the lifetime of this value.
    pub(crate) struct LockTracker<'a, L>(&'a L);

    impl<'a, L> LockTracker<'a, L> {
        /// Records that the current thread is acquiring `lock`.
        ///
        /// Panics if this thread already holds `lock`, reporting the location
        /// of the earlier acquisition.
        #[track_caller]
        pub(crate) fn new(lock: &'a L) -> Self {
            {
                let ptr = lock as *const _ as *const _;
                match HELD_LOCKS.with(|l| l.borrow_mut().insert(ptr, Location::caller())) {
                    None => {}
                    Some(prev_lock_caller) => {
                        panic!("lock already held; ptr = {:p}\n{}", ptr, prev_lock_caller)
                    }
                }
            }

            Self(lock)
        }
    }

    impl<L> Drop for LockTracker<'_, L> {
        fn drop(&mut self) {
            let Self(lock) = self;
            let ptr = *lock as *const _ as *const _;
            assert_ne!(
                HELD_LOCKS.with(|l| l.borrow_mut().remove(&ptr)),
                None,
                "must have previously been locked; ptr = {:p}",
                ptr
            );
        }
    }
}
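
// Illustrative sketch, not part of the original file: the tracker above is
// keyed by the address of the lock, so a single thread may hold guards for two
// *different* locks even with the "recursive-lock-panic" feature enabled; only
// re-locking the same lock panics. The module and test names are hypothetical.
#[cfg(all(test, feature = "recursive-lock-panic"))]
mod lock_tracker_sketch {
    #[test]
    fn distinct_locks_may_be_held_together() {
        let a = crate::Mutex::new(1u32);
        let b = crate::Mutex::new(2u32);
        // Different addresses map to different `HELD_LOCKS` entries, so no
        // recursive-lock panic is raised here.
        let guard_a = a.lock();
        let guard_b = b.lock();
        assert_eq!(*guard_a + *guard_b, 3);
    }
}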

#[cfg(test)]
mod tests {
    use super::*;

    use std::thread;

    #[test]
    fn mutex_lock_and_write() {
        let m = Mutex::<u32>::new(0);
        {
            let mut guard = m.lock();
            assert_eq!(*guard, 0);
            *guard = 5;
        }

        {
            let guard = m.lock();
            assert_eq!(*guard, 5);
        }
    }

    #[test]
    fn mutex_lock_from_different_threads() {
        const NUM_THREADS: u32 = 4;

        let m = Mutex::<u32>::new(u32::MAX);
        let m = &m;

        thread::scope(|s| {
            for i in 0..NUM_THREADS {
                let _: thread::ScopedJoinHandle<'_, _> = s.spawn(move || {
                    let prev = {
                        let mut guard = m.lock();
                        let prev = *guard;
                        *guard = i;
                        prev
                    };

                    assert!(prev == u32::MAX || prev < NUM_THREADS);
                });
            }
        });

        let guard = m.lock();
        assert!(*guard < NUM_THREADS);
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn mutex_double_lock_panic() {
        let m = Mutex::<u32>::new(0);
        let _ok_guard = m.lock();
        let _panic_guard = m.lock();
    }

    #[test]
    fn rwlock_read_lock() {
        let rw = RwLock::<u32>::new(0);

        {
            let guard = rw.read();
            assert_eq!(*guard, 0);
        }

        {
            let guard = rw.read();
            assert_eq!(*guard, 0);
        }
    }

    #[test]
    fn rwlock_write_lock() {
        let rw = RwLock::<u32>::new(0);
        {
            let mut guard = rw.write();
            assert_eq!(*guard, 0);
            *guard = 5;
        }

        {
            let guard = rw.write();
            assert_eq!(*guard, 5);
        }
    }

    #[test]
    fn rwlock_read_and_write_from_different_threads() {
        const NUM_THREADS: u32 = 4;

        let rw = RwLock::<u32>::new(u32::MAX);
        let rw = &rw;

        thread::scope(|s| {
            for i in 0..NUM_THREADS {
                let _: thread::ScopedJoinHandle<'_, _> = s.spawn(move || {
                    let prev = if i % 2 == 0 {
                        let mut guard = rw.write();
                        let prev = *guard;
                        *guard = i;
                        prev
                    } else {
                        let guard = rw.read();
                        *guard
                    };

                    assert!(prev == u32::MAX || (prev < NUM_THREADS && prev % 2 == 0));
                });
            }
        });

        let val = *rw.read();
        assert!(val < NUM_THREADS && val % 2 == 0);
    }

    #[test]
    #[cfg_attr(feature = "recursive-lock-panic", should_panic(expected = "lock already held"))]
    fn rwlock_double_read() {
        let rw = RwLock::<u32>::new(0);
        let ok_guard = rw.read();
        assert_eq!(*ok_guard, 0);
        let maybe_panic_guard = rw.read();
        assert_eq!(*maybe_panic_guard, 0);
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn rwlock_double_write_panic() {
        let rw = RwLock::<u32>::new(0);
        let _ok_guard = rw.write();
        let _panic_guard = rw.write();
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn rwlock_double_read_then_write_panic() {
        let rw = RwLock::<u32>::new(0);
        let _ok_guard = rw.read();
        let _panic_guard = rw.write();
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn rwlock_double_write_then_read_panic() {
        let rw = RwLock::<u32>::new(0);
        // Take the write lock first, then attempt a read; the recursive-lock
        // tracker detects the second acquisition and panics.
        let _ok_guard = rw.write();
        let _panic_guard = rw.read();
    }
}