netstack3_sync/
lib.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Synchronization primitives for Netstack3.
6
7#![warn(missing_docs, unreachable_patterns, unused)]
8
9extern crate alloc;
10
11#[cfg(loom)]
12pub(crate) use loom::sync;
13#[cfg(not(loom))]
14pub(crate) use std::sync;
15
16use net_types::ip::{GenericOverIp, Ip};
17
18// Re-export atomics module for loom compatibility in appropriate targets.
19pub use sync::atomic;
20
21pub mod rc;
22
23/// A [`sync::Mutex`] assuming lock poisoning will never occur.
24#[derive(Debug, Default)]
25pub struct Mutex<T>(sync::Mutex<T>);
26
27/// Lock guard for access to a [`Mutex`].
28pub type LockGuard<'a, T> = lock_guard::LockGuard<'a, Mutex<T>, sync::MutexGuard<'a, T>>;
29
30impl<T> Mutex<T> {
31    /// Creates a new mutex in an unlocked state ready for use.
32    pub fn new(t: T) -> Mutex<T> {
33        Mutex(sync::Mutex::new(t))
34    }
35
36    /// Acquires a mutex, blocking the current thread until it is able to do so.
37    ///
38    /// See [`sync::Mutex::lock`] for more details.
39    ///
40    /// # Panics
41    ///
42    /// This method may panic if the calling thread is already holding the
43    /// lock.
44    #[inline]
45    #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
46    pub fn lock(&self) -> LockGuard<'_, T> {
47        lock_guard::LockGuard::new(self, |Self(m)| m.lock().expect("unexpectedly poisoned"))
48    }
49
50    /// Consumes this mutex, returning the underlying data.
51    #[inline]
52    pub fn into_inner(self) -> T {
53        let Self(mutex) = self;
54        mutex.into_inner().expect("unexpectedly poisoned")
55    }
56
57    /// Returns a mutable reference to the underlying data.
58    ///
59    /// Since this call borrows the [`Mutex`] mutably, no actual locking needs
60    /// to take place. See [`sync::Mutex::get_mut`] for more details.
61    #[inline]
62    // TODO(https://github.com/tokio-rs/loom/pull/322): remove the disable for
63    // loom once loom's lock type supports the method.
64    #[cfg(not(loom))]
65    pub fn get_mut(&mut self) -> &mut T {
66        self.0.get_mut().expect("unexpectedly poisoned")
67    }
68}
69
70impl<T: 'static> lock_order::lock::ExclusiveLock<T> for Mutex<T> {
71    type Guard<'l> = LockGuard<'l, T>;
72
73    fn lock(&self) -> Self::Guard<'_> {
74        self.lock()
75    }
76}
77
78impl<T, I: Ip> GenericOverIp<I> for Mutex<T>
79where
80    T: GenericOverIp<I>,
81{
82    type Type = Mutex<T::Type>;
83}
84
85/// A [`sync::RwLock`] assuming lock poisoning will never occur.
86#[derive(Debug, Default)]
87pub struct RwLock<T>(sync::RwLock<T>);
88
89/// Lock guard for read access to a [`RwLock`].
90pub type RwLockReadGuard<'a, T> =
91    lock_guard::LockGuard<'a, RwLock<T>, sync::RwLockReadGuard<'a, T>>;
92
93/// Lock guard for write access to a [`RwLock`].
94pub type RwLockWriteGuard<'a, T> =
95    lock_guard::LockGuard<'a, RwLock<T>, sync::RwLockWriteGuard<'a, T>>;
96
97impl<T> RwLock<T> {
98    /// Creates a new instance of an `RwLock<T>` which is unlocked.
99    pub fn new(t: T) -> RwLock<T> {
100        RwLock(sync::RwLock::new(t))
101    }
102
103    /// Locks this rwlock with shared read access, blocking the current thread
104    /// until it can be acquired.
105    ///
106    /// See [`sync::RwLock::read`] for more details.
107    ///
108    /// # Panics
109    ///
110    /// This method may panic if the calling thread already holds the read or
111    /// write lock.
112    #[inline]
113    #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
114    pub fn read(&self) -> RwLockReadGuard<'_, T> {
115        lock_guard::LockGuard::new(self, |Self(rw)| rw.read().expect("unexpectedly poisoned"))
116    }
117
118    /// Locks this rwlock with exclusive write access, blocking the current
119    /// thread until it can be acquired.
120    ///
121    /// See [`sync::RwLock::write`] for more details.
122    ///
123    /// # Panics
124    ///
125    /// This method may panic if the calling thread already holds the read or
126    /// write lock.
127    #[inline]
128    #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
129    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
130        lock_guard::LockGuard::new(self, |Self(rw)| rw.write().expect("unexpectedly poisoned"))
131    }
132
133    /// Consumes this rwlock, returning the underlying data.
134    #[inline]
135    pub fn into_inner(self) -> T {
136        let Self(rwlock) = self;
137        rwlock.into_inner().expect("unexpectedly poisoned")
138    }
139
140    /// Returns a mutable reference to the underlying data.
141    ///
142    /// Since this call borrows the [`RwLock`] mutably, no actual locking needs
143    /// to take place. See [`sync::RwLock::get_mut`] for more details.
144    #[inline]
145    // TODO(https://github.com/tokio-rs/loom/pull/322): remove the disable for
146    // loom once loom's lock type supports the method.
147    #[cfg(not(loom))]
148    pub fn get_mut(&mut self) -> &mut T {
149        self.0.get_mut().expect("unexpectedly poisoned")
150    }
151}
152
153impl<T: 'static> lock_order::lock::ReadWriteLock<T> for RwLock<T> {
154    type ReadGuard<'l> = RwLockReadGuard<'l, T>;
155
156    type WriteGuard<'l> = RwLockWriteGuard<'l, T>;
157
158    fn read_lock(&self) -> Self::ReadGuard<'_> {
159        self.read()
160    }
161
162    fn write_lock(&self) -> Self::WriteGuard<'_> {
163        self.write()
164    }
165}
166
167impl<T, I: Ip> GenericOverIp<I> for RwLock<T>
168where
169    T: GenericOverIp<I>,
170{
171    type Type = RwLock<T::Type>;
172}
173
mod lock_guard {
    use core::fmt::{self, Debug, Formatter};
    #[cfg(not(feature = "recursive-lock-panic"))]
    use core::marker::PhantomData;
    use core::ops::{Deref, DerefMut};

    #[cfg(feature = "recursive-lock-panic")]
    use crate::lock_tracker::LockTracker;

    /// An RAII implementation used to release a lock when dropped.
    ///
    /// Wraps an inner guard `G` obtained from lock `L` to provide lock
    /// instrumentation (when the `recursive-lock-panic` feature is enabled).
    // Like std's lock guards, discarding the guard immediately releases the
    // lock, which is almost certainly a bug at the call site.
    #[must_use = "if unused the lock will immediately unlock"]
    pub struct LockGuard<'a, L, G> {
        guard: G,

        // Placed after `guard` so that the tracker's destructor is run (and the
        // unlock is tracked) after the lock is actually unlocked: struct fields
        // are dropped in declaration order.
        #[cfg(feature = "recursive-lock-panic")]
        _lock_tracker: LockTracker<'a, L>,
        #[cfg(not(feature = "recursive-lock-panic"))]
        _marker: PhantomData<&'a L>,
    }

    impl<'a, L, G> LockGuard<'a, L, G> {
        /// Acquires `lock` by calling `lock_fn` on it and wraps the returned
        /// guard.
        ///
        /// # Panics
        ///
        /// With the `recursive-lock-panic` feature enabled, panics if the
        /// calling thread already holds `lock`. Any panic raised by `lock_fn`
        /// is propagated.
        #[cfg_attr(feature = "recursive-lock-panic", track_caller)]
        pub fn new<F: FnOnce(&'a L) -> G>(lock: &'a L, lock_fn: F) -> Self {
            // The tracker is created before `lock_fn` runs, so a recursive
            // acquisition is reported before `lock_fn` can block on a lock this
            // thread already holds.
            #[cfg(feature = "recursive-lock-panic")]
            let lock_tracker = LockTracker::new(lock);

            Self {
                guard: lock_fn(lock),

                #[cfg(feature = "recursive-lock-panic")]
                _lock_tracker: lock_tracker,
                #[cfg(not(feature = "recursive-lock-panic"))]
                _marker: PhantomData,
            }
        }
    }

    impl<L, G: Deref> Deref for LockGuard<'_, L, G> {
        type Target = G::Target;

        fn deref(&self) -> &G::Target {
            self.guard.deref()
        }
    }

    impl<L, G: DerefMut> DerefMut for LockGuard<'_, L, G> {
        fn deref_mut(&mut self) -> &mut G::Target {
            self.guard.deref_mut()
        }
    }

    impl<L, G: Debug> Debug for LockGuard<'_, L, G> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // Delegate to the wrapped guard so this prints exactly like it.
            Debug::fmt(&self.guard, f)
        }
    }
}
229
230#[cfg(feature = "recursive-lock-panic")]
mod lock_tracker {
    use core::cell::RefCell;
    use core::panic::Location;
    use std::collections::hash_map::{Entry, HashMap};

    std::thread_local! {
        /// Locks currently held by this thread, keyed by the lock's address
        /// and mapped to the location that acquired it.
        static HELD_LOCKS: RefCell<HashMap<*const (), &'static Location<'static>>> =
            RefCell::new(HashMap::new());
    }

    /// Returns the key under which `lock` is tracked in `HELD_LOCKS`.
    fn lock_key<L>(lock: &L) -> *const () {
        (lock as *const L).cast()
    }

    /// An RAII object to keep track of a lock that is (or soon to be) held.
    ///
    /// The `Drop` implementation of this struct removes the lock from the
    /// thread-local table of held locks.
    pub(crate) struct LockTracker<'a, L>(&'a L);

    impl<'a, L> LockTracker<'a, L> {
        /// Tracks that the lock is to be held.
        ///
        /// This method adds the lock to the thread-local table of held locks,
        /// recording the location of the (`#[track_caller]`-propagated) caller.
        ///
        /// # Panics
        ///
        /// Panics if the lock is already held by the calling thread. The panic
        /// message includes the location that acquired it first.
        #[track_caller]
        pub(crate) fn new(lock: &'a L) -> Self {
            // `Location::caller` must be evaluated directly in this
            // `#[track_caller]` function. Closures do not inherit caller
            // tracking, so calling it inside the `with` closure below would
            // always record this file's location instead of the caller's.
            let caller = Location::caller();
            let ptr = lock_key(lock);

            // Check and insert in one lookup. Never overwrite an existing
            // entry: it identifies where the lock was actually acquired.
            let prev_lock_caller = HELD_LOCKS.with(|held| match held.borrow_mut().entry(ptr) {
                Entry::Occupied(entry) => Some(*entry.get()),
                Entry::Vacant(entry) => {
                    entry.insert(caller);
                    None
                }
            });

            // Panic outside of `with` so the `RefCell` borrow is released first.
            if let Some(prev_lock_caller) = prev_lock_caller {
                panic!("lock already held; ptr = {:p}\n{}", ptr, prev_lock_caller)
            }

            Self(lock)
        }
    }

    impl<L> Drop for LockTracker<'_, L> {
        fn drop(&mut self) {
            let Self(lock) = self;
            let ptr = lock_key(*lock);
            assert_ne!(
                HELD_LOCKS.with(|l| l.borrow_mut().remove(&ptr)),
                None,
                "must have previously been locked; ptr = {:p}",
                ptr
            );
        }
    }
}
284
#[cfg(test)]
mod tests {
    use super::*;

    use std::thread;

    #[test]
    fn mutex_lock_and_write() {
        let m = Mutex::<u32>::new(0);
        {
            let mut guard = m.lock();
            assert_eq!(*guard, 0);
            *guard = 5;
        }

        {
            let guard = m.lock();
            assert_eq!(*guard, 5);
        }
    }

    #[test]
    fn mutex_lock_from_different_threads() {
        const NUM_THREADS: u32 = 4;

        let m = Mutex::<u32>::new(u32::MAX);
        let m = &m;

        thread::scope(|s| {
            for i in 0..NUM_THREADS {
                let _: thread::ScopedJoinHandle<'_, _> = s.spawn(move || {
                    let prev = {
                        let mut guard = m.lock();
                        let prev = *guard;
                        *guard = i;
                        prev
                    };

                    assert!(prev == u32::MAX || prev < NUM_THREADS);
                });
            }
        });

        let guard = m.lock();
        assert!(*guard < NUM_THREADS);
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn mutex_double_lock_panic() {
        let m = Mutex::<u32>::new(0);
        let _ok_guard = m.lock();
        let _panic_guard = m.lock();
    }

    #[test]
    #[cfg(not(loom))]
    fn mutex_get_mut_and_into_inner() {
        let mut m = Mutex::<u32>::new(1);
        *m.get_mut() = 2;
        assert_eq!(*m.lock(), 2);
        assert_eq!(m.into_inner(), 2);
    }

    #[test]
    fn rwlock_read_lock() {
        let rw = RwLock::<u32>::new(0);

        {
            let guard = rw.read();
            assert_eq!(*guard, 0);
        }

        {
            let guard = rw.read();
            assert_eq!(*guard, 0);
        }
    }

    #[test]
    fn rwlock_write_lock() {
        let rw = RwLock::<u32>::new(0);
        {
            let mut guard = rw.write();
            assert_eq!(*guard, 0);
            *guard = 5;
        }

        {
            let guard = rw.write();
            assert_eq!(*guard, 5);
        }
    }

    #[test]
    #[cfg(not(loom))]
    fn rwlock_get_mut_and_into_inner() {
        let mut rw = RwLock::<u32>::new(1);
        *rw.get_mut() = 2;
        assert_eq!(*rw.read(), 2);
        assert_eq!(rw.into_inner(), 2);
    }

    #[test]
    fn rwlock_read_and_write_from_different_threads() {
        const NUM_THREADS: u32 = 4;

        let rw = RwLock::<u32>::new(u32::MAX);
        let rw = &rw;

        thread::scope(|s| {
            for i in 0..NUM_THREADS {
                let _: thread::ScopedJoinHandle<'_, _> = s.spawn(move || {
                    let prev = if i % 2 == 0 {
                        // Only threads with an even-numbered `i` perform a write.
                        let mut guard = rw.write();
                        let prev = *guard;
                        *guard = i;
                        prev
                    } else {
                        let guard = rw.read();
                        *guard
                    };

                    assert!(prev == u32::MAX || (prev < NUM_THREADS && prev % 2 == 0));
                });
            }
        });

        let val = *rw.read();
        assert!(val < NUM_THREADS && val % 2 == 0);
    }

    #[test]
    #[cfg_attr(feature = "recursive-lock-panic", should_panic(expected = "lock already held"))]
    fn rwlock_double_read() {
        let rw = RwLock::<u32>::new(0);
        let ok_guard = rw.read();
        assert_eq!(*ok_guard, 0);
        let maybe_panic_guard = rw.read();
        assert_eq!(*maybe_panic_guard, 0);
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn rwlock_double_write_panic() {
        let rw = RwLock::<u32>::new(0);
        let _ok_guard = rw.write();
        let _panic_guard = rw.write();
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn rwlock_double_read_then_write_panic() {
        let rw = RwLock::<u32>::new(0);
        let _ok_guard = rw.read();
        let _panic_guard = rw.write();
    }

    #[test]
    #[should_panic(expected = "lock already held")]
    #[cfg(feature = "recursive-lock-panic")]
    fn rwlock_double_write_then_read_panic() {
        let rw = RwLock::<u32>::new(0);
        let _ok_guard = rw.write();
        let _panic_guard = rw.read();
    }
}