Skip to main content

starnix_sync/
locks.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Use these crates so that we don't need to make the dependencies conditional.
6use {fuchsia_sync as _, lock_api as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use lock_api::RawMutex;
11use std::{any, fmt};
12
13pub use fuchsia_sync::{
14    MappedMutexGuard, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
15};
16
/// A trait for lock guards that can be temporarily unlocked asynchronously.
///
/// This is useful for performing async operations while holding a lock, without
/// causing deadlocks or holding the lock for an extended period.
#[async_trait::async_trait(?Send)]
pub trait AsyncUnlockable {
    /// Temporarily unlocks the guard `s`, executes the async function `f`, and then
    /// re-locks the guard.
    ///
    /// Returns the value produced by awaiting `f`.
    /// The lock is guaranteed to be re-acquired before this function returns.
    ///
    /// The guard is passed explicitly as `s` instead of `self`, so this is invoked
    /// as an associated function: `Guard::unlocked_async(&mut guard, f)`.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U;
}
29
30#[async_trait::async_trait(?Send)]
31impl<'a, T> crate::AsyncUnlockable for MutexGuard<'a, T> {
32    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
33    where
34        F: AsyncFnOnce() -> U,
35    {
36        // SAFETY: The guard always have a lock mutex.
37        unsafe {
38            Self::mutex(s).raw().unlock();
39        }
40        scopeguard::defer!(
41            // SAFETY: The mutex has been unlocked previously.
42            unsafe { Self::mutex(s).raw().lock() }
43        );
44        f().await
45    }
46}
47
48/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
49/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
50/// deadlock.
51pub fn ordered_lock<'a, T>(
52    m1: &'a Mutex<T>,
53    m2: &'a Mutex<T>,
54) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
55    let ptr1: *const Mutex<T> = m1;
56    let ptr2: *const Mutex<T> = m2;
57    if ptr1 < ptr2 {
58        let g1 = m1.lock();
59        let g2 = m2.lock();
60        (g1, g2)
61    } else {
62        let g2 = m2.lock();
63        let g1 = m1.lock();
64        (g1, g2)
65    }
66}
67
68/// Acquires multiple mutexes in a consistent order based on their memory addresses.
69/// This helps prevent deadlocks.
70pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
71    // Create a vector of tuples containing the mutex and its original index.
72    let mut indexed_mutexes = mutexes.iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
73
74    // Sort the indexed mutexes by their memory addresses.
75    indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
76
77    // Acquire the locks in the sorted order.
78    let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
79
80    // Reorder the guards to match the original order of the mutexes.
81    guards.sort_by_key(|(i, _)| *i);
82
83    guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
84}
85
/// A wrapper for mutex that requires a `Locked` context to acquire.
/// This context must be of a level that precedes `L` in the lock ordering graph
/// where `L` is a level associated with this mutex.
pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
    /// The wrapped mutex that holds the protected value.
    mutex: Mutex<T>,
    /// Records the lock level `L` in the type without storing a value of that type.
    _phantom: PhantomData<L>,
}
93
94impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
95    fn default() -> Self {
96        Self { mutex: Default::default(), _phantom: Default::default() }
97    }
98}
99
100impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
103    }
104}
105
/// Exposes the wrapped mutex to the lock-ordering machinery at level `L`.
impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
    /// The value protected by the mutex.
    type Data = T;
    /// The guard handed out while the mutex is held.
    type Guard<'a>
        = MutexGuard<'a, T>
    where
        T: 'a,
        L: 'a;
    /// Locks the wrapped mutex directly; no lock-ordering check happens here.
    fn lock(&self) -> Self::Guard<'_> {
        self.mutex.lock()
    }
}
117
impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
    /// Creates an `OrderedMutex` associated with level `L` that protects `t`.
    pub const fn new(t: T) -> Self {
        Self { mutex: Mutex::new(t), _phantom: PhantomData }
    }

    /// Acquires this mutex through the `locked` context and returns its guard.
    ///
    /// The level `P` of the context must satisfy `P: LockBefore<L>`. The context
    /// stays mutably borrowed for as long as the returned guard is alive.
    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
    where
        P: LockBefore<L>,
    {
        locked.lock(self)
    }

    /// Acquires this mutex through the `locked` context and returns its guard
    /// together with a `Locked<L>` context.
    ///
    /// The level `P` of the context must satisfy `P: LockBefore<L>`. The returned
    /// context can be passed on to acquire further locks while the guard is held.
    pub fn lock_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        locked.lock_and(self)
    }
}
140
/// Locks two `OrderedMutex`es of the same level `L` in a consistent order. Returns
/// both guards and a new locked context.
///
/// This forwards to `Locked::lock_both_and`; the context level `P` must satisfy
/// `P: LockBefore<L>`. The first returned guard belongs to `m1` and the second to
/// `m2`.
pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
    locked: &'a mut Locked<P>,
    m1: &'a OrderedMutex<T, L>,
    m2: &'a OrderedMutex<T, L>,
) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, &'a mut Locked<L>)
where
    P: LockBefore<L>,
{
    locked.lock_both_and(m1, m2)
}
153
/// A wrapper for an RwLock that requires a `Locked` context to acquire.
/// This context must be of a level that precedes `L` in the lock ordering graph
/// where `L` is a level associated with this RwLock.
pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
    /// The wrapped reader-writer lock that holds the protected value.
    rwlock: RwLock<T>,
    /// Records the lock level `L` in the type without storing a value of that type.
    _phantom: PhantomData<L>,
}
161
162impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
163    fn default() -> Self {
164        Self { rwlock: Default::default(), _phantom: Default::default() }
165    }
166}
167
168impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
171    }
172}
173
/// Exposes the wrapped reader-writer lock to the lock-ordering machinery at level `L`.
impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
    /// The value protected by the lock.
    type Data = T;
    /// The guard handed out while the lock is held for reading.
    type ReadGuard<'a>
        = RwLockReadGuard<'a, T>
    where
        T: 'a,
        L: 'a;
    /// The guard handed out while the lock is held for writing.
    type WriteGuard<'a>
        = RwLockWriteGuard<'a, T>
    where
        T: 'a,
        L: 'a;
    /// Takes the wrapped lock for reading; no lock-ordering check happens here.
    fn read_lock(&self) -> Self::ReadGuard<'_> {
        self.rwlock.read()
    }
    /// Takes the wrapped lock for writing; no lock-ordering check happens here.
    fn write_lock(&self) -> Self::WriteGuard<'_> {
        self.rwlock.write()
    }
}
193
impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
    /// Creates an `OrderedRwLock` associated with level `L` that protects `t`.
    pub const fn new(t: T) -> Self {
        Self { rwlock: RwLock::new(t), _phantom: PhantomData }
    }

    /// Acquires this lock for reading through the `locked` context.
    ///
    /// The level `P` of the context must satisfy `P: LockBefore<L>`. The context
    /// stays mutably borrowed for as long as the returned guard is alive.
    pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
    where
        P: LockBefore<L>,
    {
        locked.read_lock(self)
    }

    /// Acquires this lock for writing through the `locked` context.
    ///
    /// The level `P` of the context must satisfy `P: LockBefore<L>`. The context
    /// stays mutably borrowed for as long as the returned guard is alive.
    pub fn write<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
    where
        P: LockBefore<L>,
    {
        locked.write_lock(self)
    }

    /// Acquires this lock for reading and also returns a `Locked<L>` context.
    ///
    /// The level `P` of the context must satisfy `P: LockBefore<L>`. The returned
    /// context can be passed on to acquire further locks while the guard is held.
    pub fn read_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        locked.read_lock_and(self)
    }

    /// Acquires this lock for writing and also returns a `Locked<L>` context.
    ///
    /// The level `P` of the context must satisfy `P: LockBefore<L>`. The returned
    /// context can be passed on to acquire further locks while the guard is held.
    pub fn write_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        locked.write_lock_and(self)
    }
}
236
// Unit tests for the address-ordered lock helpers and the level-ordered lock wrappers.
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    // `ordered_lock` returns the guards in argument order for both argument orders.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
    }

    // `ordered_lock_vec` returns the guards in the same order as the input slice.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let l1 = Mutex::new(1);
        let l0 = Mutex::new(0);
        let l2 = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&l0, &l1, &l2]);
            assert_eq!(*guards[0], 0);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 2);
        }
        {
            let guards = ordered_lock_vec(&[&l2, &l1, &l0]);
            assert_eq!(*guards[0], 2);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 0);
        }
    }

    // Lock levels used only by the tests below.
    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // `OrderedMutex` and `OrderedRwLock` require their level to implement
        // `LockAfter<UninterruptibleLock>`.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    // Locks level `A`, then uses the returned context to lock level `C`.
    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(locked);
        let c_data = c.lock(&mut next_locked);

        // This won't compile
        //let _b_data = _b.lock(locked);
        //let _b_data = _b.lock(&mut next_locked);

        assert_eq!(&*a_data, &15);
        assert_eq!(&*c_data, &45);
    }
    // Same scenario for `OrderedRwLock`, covering write-then-read and read-then-write.
    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (d_data, mut next_locked) = d.write_and(locked);
            let f_data = f.read(&mut next_locked);

            // This won't compile
            //let _e_data = _e.read(locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
        {
            let (d_data, mut next_locked) = d.read_and(locked);
            let f_data = f.write(&mut next_locked);

            // This won't compile
            //let _e_data = _e.write(locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
    }

    // `lock_both` returns the guards in argument order for both argument orders.
    #[test]
    fn test_lock_both() {
        let a1: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let a2: OrderedMutex<u8, A> = OrderedMutex::new(30);
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(locked, &a1, &a2);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
        {
            let (a2_data, a1_data, _) = lock_both(locked, &a2, &a1);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
    }
}