// starnix_sync/locks.rs

// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// Use these crates so that we don't need to make the dependencies conditional.
6use fuchsia_sync as _;
7use lock_api as _;
8
9use crate::{
10    LockAfter, LockBefore, LockDepMutex, LockDepRwLock, LockFor, Locked, RwLockFor,
11    UninterruptibleLock,
12};
13
14use lock_api::RawMutex;
15use std::{any, fmt};
16
17pub use fuchsia_sync::{
18    MappedMutexGuard, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
19};
20
21/// A trait for lock guards that can be temporarily unlocked asynchronously.
22/// This is useful for performing async operations while holding a lock, without
23/// causing deadlocks or holding the lock for an extended period.
24#[async_trait::async_trait(?Send)]
25pub trait AsyncUnlockable {
26    /// Temporarily unlocks the guard `s`, executes the async function `f`, and then
27    /// re-locks the guard.
28    /// The lock is guaranteed to be re-acquired before this function returns.
29    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
30    where
31        F: AsyncFnOnce() -> U;
32}
33
/// Lets a `MutexGuard` release its mutex while an async closure runs.
#[async_trait::async_trait(?Send)]
impl<'a, T> crate::AsyncUnlockable for MutexGuard<'a, T> {
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U,
    {
        // SAFETY: `s` is a live guard, so the mutex it was obtained from is currently locked
        // and may be released through its raw handle.
        unsafe {
            Self::mutex(s).raw().unlock();
        }
        // The deferred re-lock runs when this scope is left by any path: normal completion, a
        // panic unwinding out of `f`, or this future being dropped while suspended. In every
        // case `s` is again backed by a locked mutex when the guard itself is later dropped.
        // Note that the re-lock may block.
        scopeguard::defer!(
            // SAFETY: The mutex was unlocked above, so locking it again restores the state the
            // guard `s` expects.
            unsafe { Self::mutex(s).raw().lock() }
        );
        f().await
    }
}
51
/// A generic mutex for the ordered_lock operations.
///
/// Abstracts over the mutex type used by [`ordered_lock`] and [`ordered_lock_vec`]. Each of
/// those helpers creates one [`MutexLike::Context`] via [`MutexLike::context`] and passes it to
/// every [`MutexLike::lock`] call it makes. An implementation can therefore carry state across
/// all the acquisitions of a single ordered-lock operation.
pub trait MutexLike {
    /// The guard returned by [`MutexLike::lock`]; it borrows the mutex for `'a`.
    type Guard<'a>
    where
        Self: 'a;

    /// State shared by all the `lock` calls of one ordered-lock operation.
    type Context;

    /// Creates a fresh context for one ordered-lock operation.
    fn context() -> Self::Context;

    /// Locks the mutex and returns its guard.
    ///
    /// `context` is the value created by [`MutexLike::context`] for the current ordered-lock
    /// operation. The same context is passed to every mutex that operation locks.
    fn lock(&self, context: &mut Self::Context) -> Self::Guard<'_>;
}
64
65impl<T> MutexLike for Mutex<T> {
66    type Guard<'a>
67        = MutexGuard<'a, T>
68    where
69        T: 'a;
70    type Context = ();
71
72    #[inline(always)]
73    fn context() -> Self::Context {
74        ()
75    }
76
77    #[inline(always)]
78    fn lock(&self, _context: &mut Self::Context) -> Self::Guard<'_> {
79        return self.lock();
80    }
81}
82
83/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
84/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
85/// deadlock.
86pub fn ordered_lock<'a, M: MutexLike>(m1: &'a M, m2: &'a M) -> (M::Guard<'a>, M::Guard<'a>) {
87    let mut context = M::context();
88    let ptr1: *const M = m1;
89    let ptr2: *const M = m2;
90    if ptr1 < ptr2 {
91        let g1 = m1.lock(&mut context);
92        let g2 = m2.lock(&mut context);
93        (g1, g2)
94    } else {
95        let g2 = m2.lock(&mut context);
96        let g1 = m1.lock(&mut context);
97        (g1, g2)
98    }
99}
100
101/// Acquires multiple mutexes in a consistent order based on their memory addresses.
102/// This helps prevent deadlocks.
103pub fn ordered_lock_vec<'a, M: MutexLike>(mutexes: &[&'a M]) -> Vec<M::Guard<'a>> {
104    let mut context = M::context();
105
106    // Create a vector of tuples containing the mutex and its original index.
107    let mut indexed_mutexes = mutexes.iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
108
109    // Sort the indexed mutexes by their memory addresses.
110    indexed_mutexes.sort_by_key(|(_, m)| *m as *const M);
111
112    // Acquire the locks in the sorted order.
113    let mut guards =
114        indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock(&mut context))).collect::<Vec<_>>();
115
116    // Reorder the guards to match the original order of the mutexes.
117    guards.sort_by_key(|(i, _)| *i);
118
119    guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
120}
121
122/// A wrapper for mutex that requires a `Locked` context to acquire.
123/// This context must be of a level that precedes `L` in the lock ordering graph
124/// where `L` is a level associated with this mutex.
125pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock> + crate::LockLevel> {
126    mutex: LockDepMutex<T, L>,
127}
128
129impl<T: Default, L: LockAfter<UninterruptibleLock> + crate::LockLevel> Default
130    for OrderedMutex<T, L>
131{
132    fn default() -> Self {
133        Self { mutex: LockDepMutex::new(T::default()) }
134    }
135}
136
137impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock> + crate::LockLevel> fmt::Debug
138    for OrderedMutex<T, L>
139{
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
142    }
143}
144
145impl<T, L: LockAfter<UninterruptibleLock> + crate::LockLevel> LockFor<L> for OrderedMutex<T, L> {
146    type Data = T;
147    type Guard<'a>
148        = crate::LockDepGuard<'a, T>
149    where
150        T: 'a,
151        L: 'a;
152    fn lock(&self) -> Self::Guard<'_> {
153        self.mutex.lock()
154    }
155}
156
157impl<T, L: LockAfter<UninterruptibleLock> + crate::LockLevel> OrderedMutex<T, L> {
158    pub const fn new(t: T) -> Self {
159        Self { mutex: LockDepMutex::new(t) }
160    }
161
162    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
163    where
164        P: LockBefore<L>,
165    {
166        locked.lock(self)
167    }
168
169    pub fn lock_and<'a, P>(
170        &'a self,
171        locked: &'a mut Locked<P>,
172    ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
173    where
174        P: LockBefore<L>,
175    {
176        locked.lock_and(self)
177    }
178}
179
180/// Lock two OrderedMutex of the same level in the consistent order. Returns both
181/// guards and a new locked context.
182pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock> + crate::LockLevel, P>(
183    locked: &'a mut Locked<P>,
184    m1: &'a OrderedMutex<T, L>,
185    m2: &'a OrderedMutex<T, L>,
186) -> (crate::LockDepGuard<'a, T>, crate::LockDepGuard<'a, T>, &'a mut Locked<L>)
187where
188    P: LockBefore<L>,
189{
190    locked.lock_both_and(m1, m2)
191}
192
193/// A wrapper for an RwLock that requires a `Locked` context to acquire.
194/// This context must be of a level that precedes `L` in the lock ordering graph
195/// where `L` is a level associated with this RwLock.
196pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock> + crate::LockLevel> {
197    rwlock: LockDepRwLock<T, L>,
198}
199
200impl<T: Default, L: LockAfter<UninterruptibleLock> + crate::LockLevel> Default
201    for OrderedRwLock<T, L>
202{
203    fn default() -> Self {
204        Self { rwlock: LockDepRwLock::new(T::default()) }
205    }
206}
207
208impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock> + crate::LockLevel> fmt::Debug
209    for OrderedRwLock<T, L>
210{
211    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
213    }
214}
215
216impl<T, L: LockAfter<UninterruptibleLock> + crate::LockLevel> RwLockFor<L> for OrderedRwLock<T, L> {
217    type Data = T;
218    type ReadGuard<'a>
219        = crate::LockDepReadGuard<'a, T>
220    where
221        T: 'a,
222        L: 'a;
223    type WriteGuard<'a>
224        = crate::LockDepWriteGuard<'a, T>
225    where
226        T: 'a,
227        L: 'a;
228    fn read_lock(&self) -> Self::ReadGuard<'_> {
229        self.rwlock.read()
230    }
231    fn write_lock(&self) -> Self::WriteGuard<'_> {
232        self.rwlock.write()
233    }
234}
235
236impl<T, L: LockAfter<UninterruptibleLock> + crate::LockLevel> OrderedRwLock<T, L> {
237    pub const fn new(t: T) -> Self {
238        Self { rwlock: LockDepRwLock::new(t) }
239    }
240
241    pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
242    where
243        P: LockBefore<L>,
244    {
245        locked.read_lock(self)
246    }
247
248    pub fn write<'a, P>(
249        &'a self,
250        locked: &'a mut Locked<P>,
251    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
252    where
253        P: LockBefore<L>,
254    {
255        locked.write_lock(self)
256    }
257
258    pub fn read_and<'a, P>(
259        &'a self,
260        locked: &'a mut Locked<P>,
261    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
262    where
263        P: LockBefore<L>,
264    {
265        locked.read_lock_and(self)
266    }
267
268    pub fn write_and<'a, P>(
269        &'a self,
270        locked: &'a mut Locked<P>,
271    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
272    where
273        P: LockBefore<L>,
274    {
275        locked.write_lock_and(self)
276    }
277}
278
// Unit tests for the ordered-lock helpers and the `OrderedMutex` / `OrderedRwLock` wrappers.
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    // `ordered_lock` must hand back guards in argument order, whichever order the mutexes are
    // passed in.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
    }

    // `ordered_lock_vec` must return guards in the order of the input slice, regardless of the
    // address order in which the locks were taken.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let l1 = Mutex::new(1);
        let l0 = Mutex::new(0);
        let l2 = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&l0, &l1, &l2]);
            assert_eq!(*guards[0], 0);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 2);
        }
        {
            let guards = ordered_lock_vec(&[&l2, &l1, &l0]);
            assert_eq!(*guards[0], 2);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 0);
        }
    }

    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // `OrderedMutex` and `OrderedRwLock` require their level to implement
        // `LockAfter<UninterruptibleLock>`, so every test level opts in here.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    // Takes A from the unlocked context, then uses the returned `Locked<A>` to take C
    // (A -> B -> C) while A's guard is still held.
    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(locked);
        let c_data = c.lock(&mut next_locked);

        // This won't compile
        // (both contexts are still mutably borrowed here: `locked` by `a_data`/`next_locked`,
        // and `next_locked` by `c_data`)
        //let _b_data = _b.lock(locked);
        //let _b_data = _b.lock(&mut next_locked);

        assert_eq!(&*a_data, &15);
        assert_eq!(&*c_data, &45);
    }

    // Same pattern as `test_ordered_mutex` for reader-writer locks, covering both
    // write-then-read and read-then-write through the returned `Locked<D>` context.
    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (d_data, mut next_locked) = d.write_and(locked);
            let f_data = f.read(&mut next_locked);

            // This won't compile
            // (the contexts are still mutably borrowed by the live guards above)
            //let _e_data = _e.read(locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
        {
            let (d_data, mut next_locked) = d.read_and(locked);
            let f_data = f.write(&mut next_locked);

            // This won't compile
            // (the contexts are still mutably borrowed by the live guards above)
            //let _e_data = _e.write(locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
    }

    // `lock_both` must return guards in argument order for either argument order.
    #[test]
    fn test_lock_both() {
        let a1: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let a2: OrderedMutex<u8, A> = OrderedMutex::new(30);
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(locked, &a1, &a2);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
        {
            let (a2_data, a1_data, _) = lock_both(locked, &a2, &a1);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
    }
}
419            assert_eq!(&*a1_data, &15);
420            assert_eq!(&*a2_data, &30);
421        }
422    }
423}