starnix_sync/
locks.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Use these crates so that we don't need to make the dependencies conditional.
6use {fuchsia_sync as _, lock_api as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use lock_api::RawMutex;
11use std::{any, fmt};
12
13pub use fuchsia_sync::{
14    MappedMutexGuard, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
15};
16
/// A trait for lock guards that can be temporarily unlocked asynchronously.
///
/// This is useful for performing async operations while holding a lock, without
/// causing deadlocks or holding the lock for an extended period.
#[async_trait::async_trait(?Send)]
pub trait AsyncUnlockable {
    /// Temporarily unlocks the guard `s`, awaits the async function `f`, and then
    /// re-locks the guard.
    ///
    /// The guard is passed as the explicit argument `s` rather than as a `self`
    /// receiver, so this is invoked as an associated function, e.g.
    /// `AsyncUnlockable::unlocked_async(&mut guard, f)`.
    ///
    /// Returns the value produced by `f`. The lock is guaranteed to be re-acquired
    /// before this function returns.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U;
}
29
#[async_trait::async_trait(?Send)]
impl<'a, T> crate::AsyncUnlockable for MutexGuard<'a, T> {
    /// Unlocks the mutex behind the guard `s`, awaits `f`, and locks the mutex again
    /// before handing back the value `f` produced.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U,
    {
        // SAFETY: The guard always has a locked mutex.
        unsafe {
            Self::mutex(s).raw().unlock();
        }
        // The re-lock is registered as a deferred action instead of being written after the
        // `await`, so it is tied to leaving this scope rather than to `f` finishing normally.
        // NOTE(review): presumably this also covers the future being dropped mid-await, so
        // that the guard never observes an unlocked mutex when it is later dropped - confirm.
        scopeguard::defer!(
            // SAFETY: The mutex has been unlocked previously.
            unsafe { Self::mutex(s).raw().lock() }
        );
        f().await
    }
}
47
48/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
49/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
50/// deadlock.
51pub fn ordered_lock<'a, T>(
52    m1: &'a Mutex<T>,
53    m2: &'a Mutex<T>,
54) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
55    let ptr1: *const Mutex<T> = m1;
56    let ptr2: *const Mutex<T> = m2;
57    if ptr1 < ptr2 {
58        let g1 = m1.lock();
59        let g2 = m2.lock();
60        (g1, g2)
61    } else {
62        let g2 = m2.lock();
63        let g1 = m1.lock();
64        (g1, g2)
65    }
66}
67
68/// Acquires multiple mutexes in a consistent order based on their memory addresses.
69/// This helps prevent deadlocks.
70pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
71    // Create a vector of tuples containing the mutex and its original index.
72    let mut indexed_mutexes =
73        mutexes.into_iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
74
75    // Sort the indexed mutexes by their memory addresses.
76    indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
77
78    // Acquire the locks in the sorted order.
79    let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
80
81    // Reorder the guards to match the original order of the mutexes.
82    guards.sort_by_key(|(i, _)| *i);
83
84    guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
85}
86
/// A wrapper for a mutex that requires a `Locked` context to acquire.
///
/// This context must be of a level that precedes `L` in the lock ordering graph,
/// where `L` is the level associated with this mutex.
pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
    /// The underlying mutex holding the protected `T`.
    mutex: Mutex<T>,
    /// Marker recording the lock level `L` in the type; no value of `L` is stored.
    _phantom: PhantomData<L>,
}
94
95impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
96    fn default() -> Self {
97        Self { mutex: Default::default(), _phantom: Default::default() }
98    }
99}
100
101impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
104    }
105}
106
107impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
108    type Data = T;
109    type Guard<'a>
110        = MutexGuard<'a, T>
111    where
112        T: 'a,
113        L: 'a;
114    fn lock(&self) -> Self::Guard<'_> {
115        self.mutex.lock()
116    }
117}
118
119impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
120    pub const fn new(t: T) -> Self {
121        Self { mutex: Mutex::new(t), _phantom: PhantomData }
122    }
123
124    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
125    where
126        P: LockBefore<L>,
127    {
128        locked.lock(self)
129    }
130
131    pub fn lock_and<'a, P>(
132        &'a self,
133        locked: &'a mut Locked<P>,
134    ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
135    where
136        P: LockBefore<L>,
137    {
138        locked.lock_and(self)
139    }
140}
141
142/// Lock two OrderedMutex of the same level in the consistent order. Returns both
143/// guards and a new locked context.
144pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
145    locked: &'a mut Locked<P>,
146    m1: &'a OrderedMutex<T, L>,
147    m2: &'a OrderedMutex<T, L>,
148) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, &'a mut Locked<L>)
149where
150    P: LockBefore<L>,
151{
152    locked.lock_both_and(m1, m2)
153}
154
/// A wrapper for an RwLock that requires a `Locked` context to acquire.
///
/// This context must be of a level that precedes `L` in the lock ordering graph,
/// where `L` is the level associated with this RwLock.
pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
    /// The underlying reader-writer lock holding the protected `T`.
    rwlock: RwLock<T>,
    /// Marker recording the lock level `L` in the type; no value of `L` is stored.
    _phantom: PhantomData<L>,
}
162
163impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
164    fn default() -> Self {
165        Self { rwlock: Default::default(), _phantom: Default::default() }
166    }
167}
168
169impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
172    }
173}
174
175impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
176    type Data = T;
177    type ReadGuard<'a>
178        = RwLockReadGuard<'a, T>
179    where
180        T: 'a,
181        L: 'a;
182    type WriteGuard<'a>
183        = RwLockWriteGuard<'a, T>
184    where
185        T: 'a,
186        L: 'a;
187    fn read_lock(&self) -> Self::ReadGuard<'_> {
188        self.rwlock.read()
189    }
190    fn write_lock(&self) -> Self::WriteGuard<'_> {
191        self.rwlock.write()
192    }
193}
194
195impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
196    pub const fn new(t: T) -> Self {
197        Self { rwlock: RwLock::new(t), _phantom: PhantomData }
198    }
199
200    pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
201    where
202        P: LockBefore<L>,
203    {
204        locked.read_lock(self)
205    }
206
207    pub fn write<'a, P>(
208        &'a self,
209        locked: &'a mut Locked<P>,
210    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
211    where
212        P: LockBefore<L>,
213    {
214        locked.write_lock(self)
215    }
216
217    pub fn read_and<'a, P>(
218        &'a self,
219        locked: &'a mut Locked<P>,
220    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
221    where
222        P: LockBefore<L>,
223    {
224        locked.read_lock_and(self)
225    }
226
227    pub fn write_and<'a, P>(
228        &'a self,
229        locked: &'a mut Locked<P>,
230    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
231    where
232        P: LockBefore<L>,
233    {
234        locked.write_lock_and(self)
235    }
236}
237
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    /// `ordered_lock` returns the guards in argument order, whichever way round the two
    /// mutexes are passed.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!((*g1, *g2), (1, 2));
        }
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!((*g1, *g2), (1, 2));
        }
    }

    /// `ordered_lock_vec` returns guard `i` for input mutex `i`, independently of the order
    /// in which the mutexes were declared.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let l1 = Mutex::new(1);
        let l0 = Mutex::new(0);
        let l2 = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&l0, &l1, &l2]);
            let values: Vec<_> = guards.iter().map(|guard| **guard).collect();
            assert_eq!(values, [0, 1, 2]);
        }
        {
            let guards = ordered_lock_vec(&[&l2, &l1, &l0]);
            let values: Vec<_> = guards.iter().map(|guard| **guard).collect();
            assert_eq!(values, [2, 1, 0]);
        }
    }

    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // Every test level may be used as the level of an `OrderedMutex`/`OrderedRwLock`.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    /// Mutexes can be taken along the A -> C chain, skipping the intermediate level B.
    #[test]
    fn test_ordered_mutex() {
        let a = OrderedMutex::<u8, A>::new(15);
        let _b = OrderedMutex::<u16, B>::new(30);
        let c = OrderedMutex::<u32, C>::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(locked);
        let c_data = c.lock(&mut next_locked);

        // This won't compile
        //let _b_data = _b.lock(locked);
        //let _b_data = _b.lock(&mut next_locked);

        assert_eq!(*a_data, 15);
        assert_eq!(*c_data, 45);
    }

    /// Reader-writer locks can be taken along the D -> F chain, in either read/write
    /// combination.
    #[test]
    fn test_ordered_rwlock() {
        let d = OrderedRwLock::<u8, D>::new(15);
        let _e = OrderedRwLock::<u16, E>::new(30);
        let f = OrderedRwLock::<u32, F>::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            // Write `d`, then read `f`.
            let (d_data, mut next_locked) = d.write_and(locked);
            let f_data = f.read(&mut next_locked);

            // This won't compile
            //let _e_data = _e.read(locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(*d_data, 15);
            assert_eq!(*f_data, 45);
        }
        {
            // Read `d`, then write `f`.
            let (d_data, mut next_locked) = d.read_and(locked);
            let f_data = f.write(&mut next_locked);

            // This won't compile
            //let _e_data = _e.write(locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(*d_data, 15);
            assert_eq!(*f_data, 45);
        }
    }

    /// `lock_both` returns the guards in argument order for two mutexes of the same level.
    #[test]
    fn test_lock_both() {
        let a1 = OrderedMutex::<u8, A>::new(15);
        let a2 = OrderedMutex::<u8, A>::new(30);
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(locked, &a1, &a2);
            assert_eq!((*a1_data, *a2_data), (15, 30));
        }
        {
            let (a2_data, a1_data, _) = lock_both(locked, &a2, &a1);
            assert_eq!((*a1_data, *a2_data), (15, 30));
        }
    }
}