// starnix_sync/lock_sequence.rs

// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Tools for describing and enforcing lock acquisition order.
//!
//! To use these tools:
//! 1. A lock level must be defined for each type of lock. This can be a simple enum.
//! 2. Then a relation `LockAfter` between these levels must be described,
//!    forming a graph. This graph must be acyclic, since a cycle would indicate
//!    a potential deadlock.
//! 3. Each time a lock is acquired, it must be done using an object of a `Locked<P>`
//!    type, where `P` is any lock level that comes before the level `L` that is
//!    associated with this lock. Doing so yields a new object of type `Locked<L>`
//!    that can be used to acquire subsequent locks.
//! 4. Each place where a lock is used must be marked with the maximum lock level
//!    that can already be held before attempting to acquire this lock. To do this,
//!    it takes a special marker object `Locked<P>` where `P` is a lock level that
//!    must come before the level associated with this lock in the graph. This object
//!    is then used to acquire the lock, and a new object `Locked<L>` is returned, with
//!    a new lock level `L` that comes after `P` in the lock ordering graph.
//!
//! ## Example
//! See also tests for this crate.
//!
//! ```
//! use std::sync::Mutex;
//! use starnix_sync::{lock_ordering, LockFor, Unlocked};
//!
//! #[derive(Default)]
//! struct HoldsLocks {
//!    a: Mutex<u8>,
//!    b: Mutex<u32>,
//! }
//!
//! lock_ordering! {
//!    // LockA is the top of the lock hierarchy.
//!    Unlocked => LockA,
//!    // LockA can be acquired before LockB.
//!    LockA => LockB,
//! }
//!
//! impl LockFor<LockA> for HoldsLocks {
//!    type Data = u8;
//!    type Guard<'l> = std::sync::MutexGuard<'l, u8>
//!        where Self: 'l;
//!    fn lock(&self) -> Self::Guard<'_> {
//!        self.a.lock().unwrap()
//!    }
//! }
//!
//! impl LockFor<LockB> for HoldsLocks {
//!    type Data = u32;
//!    type Guard<'l> = std::sync::MutexGuard<'l, u32>
//!        where Self: 'l;
//!    fn lock(&self) -> Self::Guard<'_> {
//!        self.b.lock().unwrap()
//!    }
//! }
//!
//! // Accessing locked state looks like this:
//!
//! let state = HoldsLocks::default();
//! // Create a new lock session with the "root" lock level, `Unlocked`.
//! // SAFETY: no lock has been acquired yet.
//! let mut locked = unsafe { Unlocked::new() };
//! // Access locked state.
//! let (a, mut locked_a) = locked.lock_and::<LockA, _>(&state);
//! let b = locked_a.lock::<LockB, _>(&state);
//! ```
//!
//! The [`lock_ordering`](crate::lock_ordering) macro provides definitions for
//! lock levels and implementations of `LockAfter` for all the locks that are
//! connected in the graph (one can be locked after another). It also prevents
//! accidental lock ordering inversion introduced while defining the graph
//! by detecting cycles in it.
//!
//! This won't compile, because the graph contains the cycle `A => B => A`:
//! ```compile_fail
//! # use starnix_sync::{lock_ordering, Unlocked};
//! lock_ordering!{
//!     Unlocked => A,
//!     A => B,
//!     B => A,
//! }
//! ```
//!
//! The methods on [Locked] prevent out-of-order locking according to the
//! specified lock relationships.
//!
//! This won't compile because `LockB` does not implement `LockBefore<LockA>`:
//! ```compile_fail
//! # use std::sync::Mutex;
//! # use starnix_sync::{lock_ordering, LockFor, Unlocked};
//! #
//! # #[derive(Default)]
//! # struct HoldsLocks {
//! #    a: Mutex<u8>,
//! #    b: Mutex<u32>,
//! # }
//! #
//! # lock_ordering! {
//! #    // LockA is the top of the lock hierarchy.
//! #    Unlocked => LockA,
//! #    // LockA can be acquired before LockB.
//! #    LockA => LockB,
//! # }
//! #
//! # impl LockFor<LockA> for HoldsLocks {
//! #    type Data = u8;
//! #    type Guard<'l> = std::sync::MutexGuard<'l, u8>
//! #        where Self: 'l;
//! #    fn lock(&self) -> Self::Guard<'_> {
//! #        self.a.lock().unwrap()
//! #    }
//! # }
//! #
//! # impl LockFor<LockB> for HoldsLocks {
//! #     type Data = u32;
//! #     type Guard<'l> = std::sync::MutexGuard<'l, u32>
//! #         where Self: 'l;
//! #     fn lock(&self) -> Self::Guard<'_> {
//! #         self.b.lock().unwrap()
//! #     }
//! # }
//! #
//!
//! let state = HoldsLocks::default();
//! // SAFETY: no lock has been acquired yet.
//! let mut locked = unsafe { Unlocked::new() };
//!
//! // Locking B without A is fine, but locking A after B is not.
//! let (b, mut locked_b) = locked.lock_and::<LockB, _>(&state);
//! // compile error: LockB does not implement LockBefore<LockA>
//! let a = locked_b.lock::<LockA, _>(&state);
//! ```
//!
//! Even if the lock guard goes out of scope, the new `Locked` instance returned
//! by [Locked::lock_and] will prevent the original one from being used to
//! access state for as long as the new instance is in use. This doesn't work:
//!
//! ```compile_fail
//! # use std::sync::Mutex;
//! # use starnix_sync::{lock_ordering, LockFor, Unlocked};
//! #
//! # #[derive(Default)]
//! # struct HoldsLocks {
//! #     a: Mutex<u8>,
//! #     b: Mutex<u32>,
//! # }
//! #
//! # lock_ordering! {
//! #    // LockA is the top of the lock hierarchy.
//! #    Unlocked => LockA,
//! #    // LockA can be acquired before LockB.
//! #    LockA => LockB,
//! # }
//! #
//! # impl LockFor<LockA> for HoldsLocks {
//! #     type Data = u8;
//! #     type Guard<'l> = std::sync::MutexGuard<'l, u8>
//! #         where Self: 'l;
//! #     fn lock(&self) -> Self::Guard<'_> {
//! #         self.a.lock().unwrap()
//! #     }
//! # }
//! #
//! # impl LockFor<LockB> for HoldsLocks {
//! #     type Data = u32;
//! #     type Guard<'l> = std::sync::MutexGuard<'l, u32>
//! #         where Self: 'l;
//! #     fn lock(&self) -> Self::Guard<'_> {
//! #         self.b.lock().unwrap()
//! #     }
//! # }
//!
//! let state = HoldsLocks::default();
//! // SAFETY: no lock has been acquired yet.
//! let mut locked = unsafe { Unlocked::new() };
//!
//! let (a, mut locked_a) = locked.lock_and::<LockA, _>(&state);
//! drop(a);
//! // Won't work; `locked` is still mutably borrowed by `locked_a`, which is used below.
//! let a = locked.lock::<LockA, _>(&state);
//! let b = locked_a.lock::<LockB, _>(&state);
//! ```
183
184use core::marker::PhantomData;
185use static_assertions::const_assert_eq;
186
187pub use crate::{LockBefore, LockEqualOrBefore, LockFor, RwLockFor};
188
/// Compile-time token recording the lock level a caller is operating at.
///
/// A `Locked<'a, L>` carries no data at runtime. It only says which lock level
/// `L` the holder is at. Any state that requires a lock to access should be
/// reached through `lock_and` (or a sibling method) on a `Locked` whose level is
/// permitted to come before that lock's level. The call yields the guard and a
/// new `Locked` at the new level. The new token mutably borrows the token it
/// was derived from, so the original cannot acquire further locks until the
/// derived token is no longer in use.
pub struct Locked<'a, L>(PhantomData<&'a L>);
199
/// The "highest" lock level: no lock is held.
///
/// `Unlocked::new` returns a `Locked` token at this level. The root of every
/// lock ordering tree built with this crate is expected to implement
/// `LockAfter<Unlocked>`. The enum has no variants, so it is only ever used as a
/// type-level marker and is never instantiated.
pub enum Unlocked {}
206
207const_assert_eq!(std::mem::size_of::<Locked<'static, Unlocked>>(), 0);
208
209impl Unlocked {
210    /// Entry point for locked access.
211    ///
212    /// `Unlocked` is the "root" lock level and can be acquired before any lock.
213    ///
214    /// # Safety
215    /// `Unlocked` should only be used before any lock in the program has been acquired.
216    #[inline(always)]
217    pub unsafe fn new() -> Locked<'static, Unlocked> {
218        Locked::<'static, Unlocked>(Default::default())
219    }
220}
221impl LockEqualOrBefore<Unlocked> for Unlocked {}
222
223// It's important that the lifetime on `Locked` here be anonymous. That means
224// that the lifetimes in the returned `Locked` objects below are inferred to
225// be the lifetimes of the references to self (mutable or immutable).
226impl<L> Locked<'_, L> {
227    /// Acquire the given lock.
228    ///
229    /// This requires that `M` can be locked after `L`.
230    #[inline(always)]
231    pub fn lock<'a, M, S>(&'a mut self, source: &'a S) -> S::Guard<'a>
232    where
233        M: 'a,
234        S: LockFor<M>,
235        L: LockBefore<M>,
236    {
237        let (data, _) = self.lock_and::<M, S>(source);
238        data
239    }
240
241    /// Acquire the given lock and a new locked context.
242    ///
243    /// This requires that `M` can be locked after `L`.
244    #[inline(always)]
245    pub fn lock_and<'a, M, S>(&'a mut self, source: &'a S) -> (S::Guard<'a>, Locked<'a, M>)
246    where
247        M: 'a,
248        S: LockFor<M>,
249        L: LockBefore<M>,
250    {
251        let data = S::lock(source);
252        (data, Locked::<'a, M>(PhantomData::default()))
253    }
254
255    /// Acquire two locks that are on the same level, in a consistent order (sorted by memory address) and return both guards
256    /// as well as the new locked context.
257    ///
258    /// This requires that `M` can be locked after `L`.
259    #[inline(always)]
260    pub fn lock_both_and<'a, M, S>(
261        &'a mut self,
262        source1: &'a S,
263        source2: &'a S,
264    ) -> (S::Guard<'a>, S::Guard<'a>, Locked<'a, M>)
265    where
266        M: 'a,
267        S: LockFor<M>,
268        L: LockBefore<M>,
269    {
270        let ptr1: *const S = source1;
271        let ptr2: *const S = source2;
272        if ptr1 < ptr2 {
273            let data1 = S::lock(source1);
274            let data2 = S::lock(source2);
275            (data1, data2, Locked::<'a, M>(PhantomData::default()))
276        } else {
277            let data2 = S::lock(source2);
278            let data1 = S::lock(source1);
279            (data1, data2, Locked::<'a, M>(PhantomData::default()))
280        }
281    }
282    /// Acquire two locks that are on the same level, in a consistent order (sorted by memory address) and return both guards.
283    ///
284    /// This requires that `M` can be locked after `L`.
285    #[inline(always)]
286    pub fn lock_both<'a, M, S>(
287        &'a mut self,
288        source1: &'a S,
289        source2: &'a S,
290    ) -> (S::Guard<'a>, S::Guard<'a>)
291    where
292        M: 'a,
293        S: LockFor<M>,
294        L: LockBefore<M>,
295    {
296        let (data1, data2, _) = self.lock_both_and(source1, source2);
297        (data1, data2)
298    }
299
300    /// Attempt to acquire the given read lock and a new locked context.
301    ///
302    /// For accessing state via reader/writer locks. This requires that `M` can
303    /// be locked after `L`.
304    #[inline(always)]
305    pub fn read_lock_and<'a, M, S>(&'a mut self, source: &'a S) -> (S::ReadGuard<'a>, Locked<'a, M>)
306    where
307        M: 'a,
308        S: RwLockFor<M>,
309        L: LockBefore<M>,
310    {
311        let data = S::read_lock(source);
312        (data, Locked::<'a, M>(PhantomData::default()))
313    }
314
315    /// Attempt to acquire the given read lock.
316    ///
317    /// For accessing state via reader/writer locks. This requires that `M` can
318    /// be locked after `L`.
319    #[inline(always)]
320    pub fn read_lock<'a, M, S>(&'a mut self, source: &'a S) -> S::ReadGuard<'a>
321    where
322        M: 'a,
323        S: RwLockFor<M>,
324        L: LockBefore<M>,
325    {
326        let (data, _) = self.read_lock_and::<M, S>(source);
327        data
328    }
329
330    /// Attempt to acquire the given write lock and a new locked context.
331    ///
332    /// For accessing state via reader/writer locks. This requires that `M` can
333    /// be locked after `L`.
334    #[inline(always)]
335    pub fn write_lock_and<'a, M, S>(
336        &'a mut self,
337        source: &'a S,
338    ) -> (S::WriteGuard<'a>, Locked<'a, M>)
339    where
340        M: 'a,
341        S: RwLockFor<M>,
342        L: LockBefore<M>,
343    {
344        let data = S::write_lock(source);
345        (data, Locked::<'a, M>(PhantomData::default()))
346    }
347
348    /// Attempt to acquire the given write lock.
349    ///
350    /// For accessing state via reader/writer locks. This requires that `M` can
351    /// be locked after `L`.
352    #[inline(always)]
353    pub fn write_lock<'a, M, S>(&'a mut self, source: &'a S) -> S::WriteGuard<'a>
354    where
355        M: 'a,
356        S: RwLockFor<M>,
357        L: LockBefore<M>,
358    {
359        let (data, _) = self.write_lock_and::<M, S>(source);
360        data
361    }
362
363    /// Restrict locking as if a lock was acquired.
364    ///
365    /// Like `lock_and` but doesn't actually acquire the lock `M`. This is
366    /// safe because any locks that could be acquired with the lock `M` held can
367    /// also be acquired without `M` being held.
368    #[inline(always)]
369    pub fn cast_locked<'a, M>(&'a mut self) -> Locked<'a, M>
370    where
371        M: 'a,
372        L: LockEqualOrBefore<M>,
373    {
374        Locked::<'a, M>(PhantomData::default())
375    }
376
377    #[inline(always)]
378    pub fn cast_locked_by_value<'a, M>(_locked: Locked<'a, L>) -> Locked<'a, M>
379    where
380        M: 'a,
381        L: LockEqualOrBefore<M>,
382    {
383        Locked::<'a, M>(PhantomData::default())
384    }
385}
386
#[cfg(test)]
mod test {
    use std::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

    #[test]
    fn example() {
        use crate::{lock_ordering, Unlocked};

        #[derive(Default)]
        pub struct HoldsLocks {
            a: Mutex<u8>,
            b: Mutex<u32>,
        }

        lock_ordering! {
            // LockA is the top of the lock hierarchy.
            Unlocked => LockA,
            // LockA can be acquired before LockB.
            LockA => LockB,
        }

        impl LockFor<LockA> for HoldsLocks {
            type Data = u8;
            type Guard<'l>
                = std::sync::MutexGuard<'l, u8>
            where
                Self: 'l;
            fn lock(&self) -> Self::Guard<'_> {
                self.a.lock().unwrap()
            }
        }

        impl LockFor<LockB> for HoldsLocks {
            type Data = u32;
            type Guard<'l>
                = std::sync::MutexGuard<'l, u32>
            where
                Self: 'l;
            fn lock(&self) -> Self::Guard<'_> {
                self.b.lock().unwrap()
            }
        }

        // Accessing locked state looks like this:

        let state = HoldsLocks::default();
        // Create a new lock session with the "root" lock level, `Unlocked`.
        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        // Access locked state.
        let (_a, mut locked_a) = locked.lock_and::<LockA, _>(&state);
        let _b = locked_a.lock::<LockB, _>(&state);
    }

    mod lock_levels {
        use crate::Unlocked;
        use lock_ordering_macro::lock_ordering;
        // Lock ordering tree:
        // A -> B -> {C, D, E -> F, G -> H}
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            B => D,
            B => E,
            E => F,
            B => G,
            G => H,
        }
    }

    use super::Locked;
    use crate::{LockFor, RwLockFor, Unlocked};
    use lock_levels::{A, B, C, D, E, F, G, H};

    /// Data type with multiple locked fields.
    #[derive(Default)]
    pub struct Data {
        a: Mutex<u8>,
        b: Mutex<u16>,
        c: Mutex<u64>,
        d: RwLock<u128>,
        e: Mutex<Mutex<u8>>,
        g: Mutex<Vec<Mutex<u8>>>,
        u: usize,
    }

    impl LockFor<A> for Data {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.a.lock().unwrap()
        }
    }

    impl LockFor<B> for Data {
        type Data = u16;
        type Guard<'l> = MutexGuard<'l, u16>;
        fn lock(&self) -> Self::Guard<'_> {
            self.b.lock().unwrap()
        }
    }

    impl LockFor<C> for Data {
        type Data = u64;
        type Guard<'l> = MutexGuard<'l, u64>;
        fn lock(&self) -> Self::Guard<'_> {
            self.c.lock().unwrap()
        }
    }

    impl RwLockFor<D> for Data {
        type Data = u128;
        type ReadGuard<'l> = RwLockReadGuard<'l, u128>;
        type WriteGuard<'l> = RwLockWriteGuard<'l, u128>;
        fn read_lock(&self) -> Self::ReadGuard<'_> {
            self.d.read().unwrap()
        }
        fn write_lock(&self) -> Self::WriteGuard<'_> {
            self.d.write().unwrap()
        }
    }

    impl LockFor<E> for Data {
        type Data = Mutex<u8>;
        type Guard<'l> = MutexGuard<'l, Mutex<u8>>;
        fn lock(&self) -> Self::Guard<'_> {
            self.e.lock().unwrap()
        }
    }

    impl LockFor<F> for Mutex<u8> {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.lock().unwrap()
        }
    }

    impl LockFor<G> for Data {
        type Data = Vec<Mutex<u8>>;
        type Guard<'l> = MutexGuard<'l, Vec<Mutex<u8>>>;
        fn lock(&self) -> Self::Guard<'_> {
            self.g.lock().unwrap()
        }
    }

    impl LockFor<H> for Mutex<u8> {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.lock().unwrap()
        }
    }

    #[test]
    fn lock_a_then_c() {
        let data = Data::default();

        // SAFETY: this test has not acquired any lock yet.
        let mut w = unsafe { Unlocked::new() };
        let (_a, mut wa) = w.lock_and::<A, _>(&data);
        let (_c, _wc) = wa.lock_and::<C, _>(&data);
        // This won't compile, because B does not come after C:
        // let _b = _wc.lock::<B, _>(&data);
    }

    #[test]
    fn cast_a_then_c() {
        let data = Data::default();

        // SAFETY: this test has not acquired any lock yet.
        let mut w = unsafe { Unlocked::new() };
        let mut wa = w.cast_locked::<A>();
        let (_c, _wc) = wa.lock_and::<C, _>(&data);
        // This should not compile, because `w` is still mutably borrowed through `wa`:
        // let _b = w.lock::<B, _>(&data);
    }

    #[test]
    fn cast_by_value_then_c() {
        let data = Data { c: Mutex::new(42), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let w = unsafe { Unlocked::new() };
        let mut wa = Locked::<Unlocked>::cast_locked_by_value::<A>(w);
        let c = wa.lock::<C, _>(&data);
        assert_eq!(*c, 42);
    }

    #[test]
    fn unlocked_access_does_not_prevent_locking() {
        let data = Data { a: Mutex::new(15), u: 34, ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        let u = &data.u;

        // Prove that `u` does not prevent locked state from being accessed.
        let a = locked.lock::<A, _>(&data);

        assert_eq!(u, &34);
        assert_eq!(&*a, &15);
    }

    #[test]
    fn nested_locks() {
        let data = Data { e: Mutex::new(Mutex::new(1)), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        let (e, mut next_locked) = locked.lock_and::<E, _>(&data);
        let v = next_locked.lock::<F, _>(&*e);
        assert_eq!(*v, 1);
    }

    #[test]
    fn rw_lock() {
        let data = Data { d: RwLock::new(1), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        {
            let mut d = locked.write_lock::<D, _>(&data);
            *d = 10;
        }
        let d = locked.read_lock::<D, _>(&data);
        assert_eq!(*d, 10);
    }

    #[test]
    fn rw_lock_and() {
        let data = Data { d: RwLock::new(1), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        {
            let (mut d, _locked_d) = locked.write_lock_and::<D, _>(&data);
            *d += 1;
        }
        let (d, _locked_d) = locked.read_lock_and::<D, _>(&data);
        assert_eq!(*d, 2);
    }

    #[test]
    fn collections() {
        let data = Data { g: Mutex::new(vec![Mutex::new(0), Mutex::new(1)]), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        let (g, mut next_locked) = locked.lock_and::<G, _>(&data);
        let v = next_locked.lock::<H, _>(&g[1]);
        assert_eq!(*v, 1);
    }

    #[test]
    fn lock_same_level() {
        let data1 = Data { a: Mutex::new(5), b: Mutex::new(15), ..Data::default() };
        let data2 = Data { a: Mutex::new(10), b: Mutex::new(20), ..Data::default() };
        // SAFETY: this test has not acquired any lock yet.
        let mut locked = unsafe { Unlocked::new() };
        {
            let (a1, a2, mut new_locked) = locked.lock_both_and::<A, _>(&data1, &data2);
            assert_eq!(*a1, 5);
            assert_eq!(*a2, 10);
            let (b1, b2) = new_locked.lock_both::<B, _>(&data1, &data2);
            assert_eq!(*b1, 15);
            assert_eq!(*b2, 20);
        }
        {
            // Guards come back in argument order regardless of which was locked first.
            let (a2, a1) = locked.lock_both::<A, _>(&data2, &data1);
            assert_eq!(*a1, 5);
            assert_eq!(*a2, 10);
        }
    }
}