starnix_sync/
locks.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Use these crates so that we don't need to make the dependencies conditional.
6use {fuchsia_sync as _, lock_api as _, tracing_mutex as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use std::{any, fmt};
11
// Lock type aliases used by the rest of this crate.
// Release builds (`not(debug_assertions)`) use the `fuchsia_sync` lock types directly.
// Debug builds wrap the raw `fuchsia_sync` locks in `tracing_mutex::lockapi::TracingWrapper`
// and expose them through `lock_api`'s generic `Mutex`/`RwLock` types.
// NOTE(review): the wrapper is presumably there to detect lock-order violations at runtime in
// debug builds (the purpose of the `tracing_mutex` crate) — confirm against its docs.

/// The mutex type used by this crate (release build: `fuchsia_sync::Mutex`).
#[cfg(not(debug_assertions))]
pub type Mutex<T> = fuchsia_sync::Mutex<T>;
/// Guard returned when a `Mutex` is locked (release build).
#[cfg(not(debug_assertions))]
pub type MutexGuard<'a, T> = fuchsia_sync::MutexGuard<'a, T>;
/// Mapped mutex guard (release build; see `lock_api::MappedMutexGuard`).
// NOTE(review): `unused` is allowed presumably because nothing in this crate uses the alias yet;
// confirm and record the intended users.
#[allow(unused)]
#[cfg(not(debug_assertions))]
pub type MappedMutexGuard<'a, T> = fuchsia_sync::MappedMutexGuard<'a, T>;

/// The reader-writer lock type used by this crate (release build: `fuchsia_sync::RwLock`).
#[cfg(not(debug_assertions))]
pub type RwLock<T> = fuchsia_sync::RwLock<T>;
/// Guard returned by `RwLock::read` (release build).
#[cfg(not(debug_assertions))]
pub type RwLockReadGuard<'a, T> = fuchsia_sync::RwLockReadGuard<'a, T>;
/// Guard returned by `RwLock::write` (release build).
#[cfg(not(debug_assertions))]
pub type RwLockWriteGuard<'a, T> = fuchsia_sync::RwLockWriteGuard<'a, T>;

// Debug build: the raw `fuchsia_sync` mutex wrapped by `tracing_mutex`.
#[cfg(debug_assertions)]
type RawTracingMutex = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncMutex>;
/// The mutex type used by this crate (debug build: `lock_api::Mutex` over the traced raw mutex).
#[cfg(debug_assertions)]
pub type Mutex<T> = lock_api::Mutex<RawTracingMutex, T>;
/// Guard returned when a `Mutex` is locked (debug build).
#[cfg(debug_assertions)]
pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawTracingMutex, T>;
/// Mapped mutex guard (debug build; see `lock_api::MappedMutexGuard`).
// NOTE(review): `unused` is allowed presumably because nothing in this crate uses the alias yet;
// confirm and record the intended users.
#[allow(unused)]
#[cfg(debug_assertions)]
pub type MappedMutexGuard<'a, T> = lock_api::MappedMutexGuard<'a, RawTracingMutex, T>;

// Debug build: the raw `fuchsia_sync` reader-writer lock wrapped by `tracing_mutex`.
#[cfg(debug_assertions)]
type RawTracingRwLock = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncRwLock>;
/// The reader-writer lock type used by this crate (debug build: `lock_api::RwLock` over the
/// traced raw lock).
#[cfg(debug_assertions)]
pub type RwLock<T> = lock_api::RwLock<RawTracingRwLock, T>;
/// Guard returned by `RwLock::read` (debug build).
#[cfg(debug_assertions)]
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawTracingRwLock, T>;
/// Guard returned by `RwLock::write` (debug build).
#[cfg(debug_assertions)]
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawTracingRwLock, T>;
45
46/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
47/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
48/// deadlock.
49pub fn ordered_lock<'a, T>(
50    m1: &'a Mutex<T>,
51    m2: &'a Mutex<T>,
52) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
53    let ptr1: *const Mutex<T> = m1;
54    let ptr2: *const Mutex<T> = m2;
55    if ptr1 < ptr2 {
56        let g1 = m1.lock();
57        let g2 = m2.lock();
58        (g1, g2)
59    } else {
60        let g2 = m2.lock();
61        let g1 = m1.lock();
62        (g1, g2)
63    }
64}
65
66/// A wrapper for mutex that requires a `Locked` context to acquire.
67/// This context must be of a level that precedes `L` in the lock ordering graph
68/// where `L` is a level associated with this mutex.
69pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
70    mutex: Mutex<T>,
71    _phantom: PhantomData<L>,
72}
73
74impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
75    fn default() -> Self {
76        Self { mutex: Default::default(), _phantom: Default::default() }
77    }
78}
79
80impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
83    }
84}
85
86impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
87    type Data = T;
88    type Guard<'a>
89        = MutexGuard<'a, T>
90    where
91        T: 'a,
92        L: 'a;
93    fn lock(&self) -> Self::Guard<'_> {
94        self.mutex.lock()
95    }
96}
97
98impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
99    pub const fn new(t: T) -> Self {
100        Self { mutex: Mutex::new(t), _phantom: PhantomData }
101    }
102
103    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<'_, P>) -> <Self as LockFor<L>>::Guard<'a>
104    where
105        P: LockBefore<L>,
106    {
107        locked.lock(self)
108    }
109
110    pub fn lock_and<'a, P>(
111        &'a self,
112        locked: &'a mut Locked<'_, P>,
113    ) -> (<Self as LockFor<L>>::Guard<'a>, Locked<'a, L>)
114    where
115        P: LockBefore<L>,
116    {
117        locked.lock_and(self)
118    }
119}
120
121/// Lock two OrderedMutex of the same level in the consistent order. Returns both
122/// guards and a new locked context.
123pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
124    locked: &'a mut Locked<'_, P>,
125    m1: &'a OrderedMutex<T, L>,
126    m2: &'a OrderedMutex<T, L>,
127) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, Locked<'a, L>)
128where
129    P: LockBefore<L>,
130{
131    locked.lock_both_and(m1, m2)
132}
133
134/// A wrapper for an RwLock that requires a `Locked` context to acquire.
135/// This context must be of a level that precedes `L` in the lock ordering graph
136/// where `L` is a level associated with this RwLock.
137pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
138    rwlock: RwLock<T>,
139    _phantom: PhantomData<L>,
140}
141
142impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
143    fn default() -> Self {
144        Self { rwlock: Default::default(), _phantom: Default::default() }
145    }
146}
147
148impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
151    }
152}
153
154impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
155    type Data = T;
156    type ReadGuard<'a>
157        = RwLockReadGuard<'a, T>
158    where
159        T: 'a,
160        L: 'a;
161    type WriteGuard<'a>
162        = RwLockWriteGuard<'a, T>
163    where
164        T: 'a,
165        L: 'a;
166    fn read_lock(&self) -> Self::ReadGuard<'_> {
167        self.rwlock.read()
168    }
169    fn write_lock(&self) -> Self::WriteGuard<'_> {
170        self.rwlock.write()
171    }
172}
173
174impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
175    pub const fn new(t: T) -> Self {
176        Self { rwlock: RwLock::new(t), _phantom: PhantomData }
177    }
178
179    pub fn read<'a, P>(
180        &'a self,
181        locked: &'a mut Locked<'_, P>,
182    ) -> <Self as RwLockFor<L>>::ReadGuard<'a>
183    where
184        P: LockBefore<L>,
185    {
186        locked.read_lock(self)
187    }
188
189    pub fn write<'a, P>(
190        &'a self,
191        locked: &'a mut Locked<'_, P>,
192    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
193    where
194        P: LockBefore<L>,
195    {
196        locked.write_lock(self)
197    }
198
199    pub fn read_and<'a, P>(
200        &'a self,
201        locked: &'a mut Locked<'_, P>,
202    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, Locked<'a, L>)
203    where
204        P: LockBefore<L>,
205    {
206        locked.read_lock_and(self)
207    }
208
209    pub fn write_and<'a, P>(
210        &'a self,
211        locked: &'a mut Locked<'_, P>,
212    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, Locked<'a, L>)
213    where
214        P: LockBefore<L>,
215    {
216        locked.write_lock_and(self)
217    }
218}
219
// Tests for `ordered_lock` and for the `OrderedMutex` / `OrderedRwLock` wrappers.
// The commented-out lines inside the tests are deliberately disabled: each is marked as code
// that must not compile.
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    // `ordered_lock` must return the guards in argument order, whichever way round the two
    // mutexes are passed.
    // NOTE(review): this test uses `#[::fuchsia::test]` while the other tests use `#[test]`;
    // confirm whether the difference is intentional.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
        // The guards from the scope above have been dropped, so both mutexes are free again
        // before being locked with the arguments swapped.
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
    }

    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        // Declares the test lock levels and the edges between them.
        // NOTE(review): presumably the macro also generates the `LockBefore`/`LockAfter` impls
        // for these edges; confirm against `lock_ordering_macro`.
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // `OrderedMutex` and `OrderedRwLock` require their level `L` to implement
        // `LockAfter<UninterruptibleLock>`, so every test level opts in here.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    // Locks level A, then uses the level-A context returned by `lock_and` to lock level C
    // (C follows A via B in the tree above).
    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        // `_b` is only referenced by the commented-out calls below.
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        // SAFETY: NOTE(review): the contract of `Unlocked::new` is defined elsewhere; presumably
        // it requires that the caller holds no ordered locks, which is true at this point.
        let mut locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(&mut locked);
        let c_data = c.lock(&mut next_locked);

        // This won't compile
        //let _b_data = _b.lock(&mut locked);
        //let _b_data = _b.lock(&mut next_locked);
        // NOTE(review): presumably rejected because `locked` and `next_locked` are still mutably
        // borrowed by the guards that are alive until the asserts below.

        assert_eq!(&*a_data, &15);
        assert_eq!(&*c_data, &45);
    }
    // Same shape as `test_ordered_mutex` for `OrderedRwLock`: D is taken first and F is taken
    // through the returned level-D context. It is run twice, once as write-then-read and once as
    // read-then-write.
    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        // `_e` is only referenced by the commented-out calls below.
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        // SAFETY: NOTE(review): the contract of `Unlocked::new` is defined elsewhere; presumably
        // it requires that the caller holds no ordered locks, which is true at this point.
        let mut locked = unsafe { Unlocked::new() };
        // Each scope drops its guards before the next scope locks `d` and `f` again.
        {
            let (d_data, mut next_locked) = d.write_and(&mut locked);
            let f_data = f.read(&mut next_locked);

            // This won't compile
            //let _e_data = _e.read(&mut locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
        {
            let (d_data, mut next_locked) = d.read_and(&mut locked);
            let f_data = f.write(&mut next_locked);

            // This won't compile
            //let _e_data = _e.write(&mut locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
    }

    // `lock_both` on two same-level mutexes returns the guards in argument order for both
    // argument orders; the returned level-A context is not used.
    #[test]
    fn test_lock_both() {
        let a1: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let a2: OrderedMutex<u8, A> = OrderedMutex::new(30);
        // SAFETY: NOTE(review): the contract of `Unlocked::new` is defined elsewhere; presumably
        // it requires that the caller holds no ordered locks, which is true at this point.
        let mut locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(&mut locked, &a1, &a2);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
        {
            let (a2_data, a1_data, _) = lock_both(&mut locked, &a2, &a1);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
    }
}