starnix_sync/lock_sequence.rs
1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Tools for describing and enforcing lock acquisition order.
6//!
7//! To use these tools:
8//! 1. A lock level must be defined for each type of lock. This can be a simple enum.
9//! 2. Then a relation `LockedAfter` between these levels must be described,
10//! forming a graph. This graph must be acyclic, since a cycle would indicate
11//! a potential deadlock.
12//! 3. Each time a lock is acquired, it must be done using an object of a `Locked<P>`
13//! type, where `P` is any lock level that comes before the level `L` that is
14//! associated with this lock. Doing so yields a new object of type `Locked<L>`
15//! that can be used to acquire subsequent locks.
//! 4. Each place where a lock is used must be marked with the maximum lock level
//!    that can already be held before attempting to acquire this lock. To do this,
//!    it takes a special marker object `Locked<P>` where `P` is a lock level that
//!    must come before the level associated with this lock in the graph. This object
//!    is then used to acquire the lock, and a new object `Locked<L>` is returned, with
//!    a new lock level `L` that comes after `P` in the lock ordering graph.
22//!
23//! ## Example
24//! See also tests for this crate.
25//!
26//! ```
27//! use std::sync::Mutex;
28//! use starnix_sync::{lock_ordering, lock::LockFor, relation::LockAfter, Unlocked};
29//!
30//! #[derive(Default)]
31//! struct HoldsLocks {
32//! a: Mutex<u8>,
33//! b: Mutex<u32>,
34//! }
35//!
//! lock_ordering! {
//!     // LockA is the top of the lock hierarchy.
//!     Unlocked => LockA,
//!     // LockA can be acquired before LockB.
//!     LockA => LockB,
//! }
42//!
43//! impl LockFor<LockA> for HoldsLocks {
44//! type Data = u8;
45//! type Guard<'l> = std::sync::MutexGuard<'l, u8>
46//! where Self: 'l;
47//! fn lock(&self) -> Self::Guard<'_> {
48//! self.a.lock().unwrap()
49//! }
50//! }
51//!
52//! impl LockFor<LockB> for HoldsLocks {
53//! type Data = u32;
54//! type Guard<'l> = std::sync::MutexGuard<'l, u32>
55//! where Self: 'l;
56//! fn lock(&self) -> Self::Guard<'_> {
57//! self.b.lock().unwrap()
58//! }
59//! }
60//!
61//! // Accessing locked state looks like this:
62//!
63//! let state = HoldsLocks::default();
//! // Create a new lock session with the "root" lock level (`Unlocked`).
//! // `Unlocked::new` is unsafe: it must only be called before any lock is acquired.
//! let mut locked = unsafe { Unlocked::new() };
66//! // Access locked state.
67//! let (a, mut locked_a) = locked.lock_and::<LockA, _>(&state);
68//! let b = locked_a.lock::<LockB, _>(&state);
69//! ```
70//!
71//! The [lock_ordering] macro provides definitions for lock levels and
72//! implementations of [LockAfter] for all the locks that are connected
73//! in the graph (one can be locked after another). It also prevents
74//! accidental lock ordering inversion introduced while defining the graph
75//! by detecting cycles in it.
76//!
77//! This won't compile:
78//! ```compile_fail
79//! lock_ordering!{
80//! Unlocked => A,
81//! A => B,
82//! B => A,
83//! }
84//! ```
85//!
86//! The methods on [Locked] prevent out-of-order locking according to the
87//! specified lock relationships.
88//!
89//! This won't compile because `LockB` does not implement `LockBefore<LockA>`:
90//! ```compile_fail
91//! # use std::sync::Mutex;
92//! # use starnix_sync::{lock_ordering, lock::LockFor, Locked, Unlocked};
93//! #
94//! # #[derive(Default)]
95//! # struct HoldsLocks {
96//! # a: Mutex<u8>,
97//! # b: Mutex<u32>,
98//! # }
99//! #
100//! # lock_ordering! {
101//! # // LockA is the top of the lock hierarchy.
102//! # Unlocked => LockA,
103//! # // LockA can be acquired before LockB.
104//! # LockA => LockB,
105//! # }
106//! #
107//! # impl LockFor<LockA> for HoldsLocks {
108//! # type Data = u8;
109//! # type Guard<'l> = std::sync::MutexGuard<'l, u8>
110//! # where Self: 'l;
111//! # fn lock(&self) -> Self::Guard<'_> {
112//! # self.a.lock().unwrap()
113//! # }
114//! # }
115//! #
116//! # impl LockFor<LockB> for HoldsLocks {
117//! # type Data = u32;
118//! # type Guard<'l> = std::sync::MutexGuard<'l, u32>
119//! # where Self: 'l;
120//! # fn lock(&self) -> Self::Guard<'_> {
121//! # self.b.lock().unwrap()
122//! # }
123//! # }
124//! #
125//!
//! let state = HoldsLocks::default();
//! let mut locked = unsafe { Unlocked::new() };
128//!
129//! // Locking B without A is fine, but locking A after B is not.
130//! let (b, mut locked_b) = locked.lock_and::<LockB, _>(&state);
131//! // compile error: LockB does not implement LockBefore<LockA>
132//! let a = locked_b.lock::<LockA, _>(&state);
133//! ```
134//!
135//! Even if the lock guard goes out of scope, the new `Locked` instance returned
136//! by [Locked::lock_and] will prevent the original one from being used to
137//! access state. This doesn't work:
138//!
139//! ```compile_fail
140//! # use std::sync::Mutex;
141//! # use starnix_sync::{lock_ordering, lock::LockFor, Locked, Unlocked};
142//! #
143//! # #[derive(Default)]
144//! # struct HoldsLocks {
145//! # a: Mutex<u8>,
146//! # b: Mutex<u32>,
147//! # }
148//! #
149//! # lock_ordering! {
150//! # // LockA is the top of the lock hierarchy.
151//! # Unlocked => LockA,
152//! # // LockA can be acquired before LockB.
153//! # LockA => LockB,
154//! # }
155//! #
156//! # impl LockFor<LockA> for HoldsLocks {
157//! # type Data = u8;
158//! # type Guard<'l> = std::sync::MutexGuard<'l, u8>
159//! # where Self: 'l;
160//! # fn lock(&self) -> Self::Guard<'_> {
161//! # self.a.lock().unwrap()
162//! # }
163//! # }
164//! #
165//! # impl LockFor<LockB> for HoldsLocks {
166//! # type Data = u32;
167//! # type Guard<'l> = std::sync::MutexGuard<'l, u32>
168//! # where Self: 'l;
169//! # fn lock(&self) -> Self::Guard<'_> {
170//! # self.b.lock().unwrap()
171//! # }
172//! # }
173//!
//! let state = HoldsLocks::default();
//! let mut locked = unsafe { Unlocked::new() };
//!
//! let (b, mut locked_b) = locked.lock_and::<LockB, _>(&state);
//! drop(b);
//! let b = locked_b.lock::<LockB, _>(&state);
//! // Won't work; `locked` is mutably borrowed by `locked_b`.
//! let a = locked.lock::<LockA, _>(&state);
182//! ```
183
184use core::marker::PhantomData;
185use static_assertions::const_assert_eq;
186
187pub use crate::{LockBefore, LockEqualOrBefore, LockFor, RwLockFor};
188
/// Compile-time token that records the current lock level.
///
/// A `Locked<'_, L>` stores nothing besides a `PhantomData` marker; it exists
/// purely so the type system can track the lock level `L`. State that sits
/// behind a lock should be reached by calling `lock_and` (or one of its
/// sibling methods) on a `Locked` of a suitable level. Such a call hands back
/// the guard together with a fresh `Locked` at the new level. The fresh
/// instance mutably borrows the instance it was derived from, so the older
/// instance cannot be used to take further locks until the newer one has left
/// scope.
pub struct Locked<'a, L>(PhantomData<&'a L>);
199
/// "Highest" lock level.
///
/// This is the lock level of the context returned by [`Unlocked::new`], i.e.
/// the level at which no lock has been acquired yet. Every lock ordering tree
/// hangs off this level: the root of a hierarchy is declared as lockable
/// after `Unlocked` (for example `Unlocked => A` inside `lock_ordering!`).
///
/// The enum has no variants, so it can never be instantiated; it is only ever
/// used as a type-level marker.
pub enum Unlocked {}
206
207const_assert_eq!(std::mem::size_of::<Locked<'static, Unlocked>>(), 0);
208
209impl Unlocked {
210 /// Entry point for locked access.
211 ///
212 /// `Unlocked` is the "root" lock level and can be acquired before any lock.
213 ///
214 /// # Safety
215 /// `Unlocked` should only be used before any lock in the program has been acquired.
216 #[inline(always)]
217 pub unsafe fn new() -> Locked<'static, Unlocked> {
218 Locked::<'static, Unlocked>(Default::default())
219 }
220}
221impl LockEqualOrBefore<Unlocked> for Unlocked {}
222
223// It's important that the lifetime on `Locked` here be anonymous. That means
224// that the lifetimes in the returned `Locked` objects below are inferred to
225// be the lifetimes of the references to self (mutable or immutable).
226impl<L> Locked<'_, L> {
227 /// Acquire the given lock.
228 ///
229 /// This requires that `M` can be locked after `L`.
230 #[inline(always)]
231 pub fn lock<'a, M, S>(&'a mut self, source: &'a S) -> S::Guard<'a>
232 where
233 M: 'a,
234 S: LockFor<M>,
235 L: LockBefore<M>,
236 {
237 let (data, _) = self.lock_and::<M, S>(source);
238 data
239 }
240
241 /// Acquire the given lock and a new locked context.
242 ///
243 /// This requires that `M` can be locked after `L`.
244 #[inline(always)]
245 pub fn lock_and<'a, M, S>(&'a mut self, source: &'a S) -> (S::Guard<'a>, Locked<'a, M>)
246 where
247 M: 'a,
248 S: LockFor<M>,
249 L: LockBefore<M>,
250 {
251 let data = S::lock(source);
252 (data, Locked::<'a, M>(PhantomData::default()))
253 }
254
255 /// Acquire two locks that are on the same level, in a consistent order (sorted by memory address) and return both guards
256 /// as well as the new locked context.
257 ///
258 /// This requires that `M` can be locked after `L`.
259 #[inline(always)]
260 pub fn lock_both_and<'a, M, S>(
261 &'a mut self,
262 source1: &'a S,
263 source2: &'a S,
264 ) -> (S::Guard<'a>, S::Guard<'a>, Locked<'a, M>)
265 where
266 M: 'a,
267 S: LockFor<M>,
268 L: LockBefore<M>,
269 {
270 let ptr1: *const S = source1;
271 let ptr2: *const S = source2;
272 if ptr1 < ptr2 {
273 let data1 = S::lock(source1);
274 let data2 = S::lock(source2);
275 (data1, data2, Locked::<'a, M>(PhantomData::default()))
276 } else {
277 let data2 = S::lock(source2);
278 let data1 = S::lock(source1);
279 (data1, data2, Locked::<'a, M>(PhantomData::default()))
280 }
281 }
282 /// Acquire two locks that are on the same level, in a consistent order (sorted by memory address) and return both guards.
283 ///
284 /// This requires that `M` can be locked after `L`.
285 #[inline(always)]
286 pub fn lock_both<'a, M, S>(
287 &'a mut self,
288 source1: &'a S,
289 source2: &'a S,
290 ) -> (S::Guard<'a>, S::Guard<'a>)
291 where
292 M: 'a,
293 S: LockFor<M>,
294 L: LockBefore<M>,
295 {
296 let (data1, data2, _) = self.lock_both_and(source1, source2);
297 (data1, data2)
298 }
299
300 /// Attempt to acquire the given read lock and a new locked context.
301 ///
302 /// For accessing state via reader/writer locks. This requires that `M` can
303 /// be locked after `L`.
304 #[inline(always)]
305 pub fn read_lock_and<'a, M, S>(&'a mut self, source: &'a S) -> (S::ReadGuard<'a>, Locked<'a, M>)
306 where
307 M: 'a,
308 S: RwLockFor<M>,
309 L: LockBefore<M>,
310 {
311 let data = S::read_lock(source);
312 (data, Locked::<'a, M>(PhantomData::default()))
313 }
314
315 /// Attempt to acquire the given read lock.
316 ///
317 /// For accessing state via reader/writer locks. This requires that `M` can
318 /// be locked after `L`.
319 #[inline(always)]
320 pub fn read_lock<'a, M, S>(&'a mut self, source: &'a S) -> S::ReadGuard<'a>
321 where
322 M: 'a,
323 S: RwLockFor<M>,
324 L: LockBefore<M>,
325 {
326 let (data, _) = self.read_lock_and::<M, S>(source);
327 data
328 }
329
330 /// Attempt to acquire the given write lock and a new locked context.
331 ///
332 /// For accessing state via reader/writer locks. This requires that `M` can
333 /// be locked after `L`.
334 #[inline(always)]
335 pub fn write_lock_and<'a, M, S>(
336 &'a mut self,
337 source: &'a S,
338 ) -> (S::WriteGuard<'a>, Locked<'a, M>)
339 where
340 M: 'a,
341 S: RwLockFor<M>,
342 L: LockBefore<M>,
343 {
344 let data = S::write_lock(source);
345 (data, Locked::<'a, M>(PhantomData::default()))
346 }
347
348 /// Attempt to acquire the given write lock.
349 ///
350 /// For accessing state via reader/writer locks. This requires that `M` can
351 /// be locked after `L`.
352 #[inline(always)]
353 pub fn write_lock<'a, M, S>(&'a mut self, source: &'a S) -> S::WriteGuard<'a>
354 where
355 M: 'a,
356 S: RwLockFor<M>,
357 L: LockBefore<M>,
358 {
359 let (data, _) = self.write_lock_and::<M, S>(source);
360 data
361 }
362
363 /// Restrict locking as if a lock was acquired.
364 ///
365 /// Like `lock_and` but doesn't actually acquire the lock `M`. This is
366 /// safe because any locks that could be acquired with the lock `M` held can
367 /// also be acquired without `M` being held.
368 #[inline(always)]
369 pub fn cast_locked<'a, M>(&'a mut self) -> Locked<'a, M>
370 where
371 M: 'a,
372 L: LockEqualOrBefore<M>,
373 {
374 Locked::<'a, M>(PhantomData::default())
375 }
376
377 #[inline(always)]
378 pub fn cast_locked_by_value<'a, M>(_locked: Locked<'a, L>) -> Locked<'a, M>
379 where
380 M: 'a,
381 L: LockEqualOrBefore<M>,
382 {
383 Locked::<'a, M>(PhantomData::default())
384 }
385}
386
#[cfg(test)]
mod test {
    use std::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

    /// Two mutexes, two lock levels, locked in the declared order.
    #[test]
    fn example() {
        use crate::{lock_ordering, Unlocked};

        #[derive(Default)]
        pub struct HoldsLocks {
            a: Mutex<u8>,
            b: Mutex<u32>,
        }

        lock_ordering! {
            // LockA sits directly below the root level.
            Unlocked => LockA,
            // LockB may be taken while LockA is held.
            LockA => LockB,
        }

        impl LockFor<LockA> for HoldsLocks {
            type Data = u8;
            type Guard<'l>
                = MutexGuard<'l, u8>
            where
                Self: 'l;
            fn lock(&self) -> Self::Guard<'_> {
                self.a.lock().unwrap()
            }
        }

        impl LockFor<LockB> for HoldsLocks {
            type Data = u32;
            type Guard<'l>
                = MutexGuard<'l, u32>
            where
                Self: 'l;
            fn lock(&self) -> Self::Guard<'_> {
                self.b.lock().unwrap()
            }
        }

        // Typical usage: start from the root context and walk down the levels.

        let holder = HoldsLocks::default();
        // Start a lock session at the "root" lock level (`Unlocked`).
        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        // Take `a`, then use the context it yields to take `b`.
        let (_guard_a, mut after_a) = root.lock_and::<LockA, _>(&holder);
        let _guard_b = after_a.lock::<LockB, _>(&holder);
    }

    mod lock_levels {
        use crate::Unlocked;
        use lock_ordering_macro::lock_ordering;
        // Shape of the lock ordering tree declared below:
        // A -> B -> {C, D, E -> F, G -> H}
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            B => D,
            B => E,
            E => F,
            B => G,
            G => H,
        }
    }

    use crate::{LockFor, RwLockFor, Unlocked};
    use lock_levels::{A, B, C, D, E, F, G, H};

    /// Fixture holding one lock per level used by the tests below, plus a
    /// plain field (`u`) that is not behind any lock.
    #[derive(Default)]
    pub struct Data {
        a: Mutex<u8>,
        b: Mutex<u16>,
        c: Mutex<u64>,
        d: RwLock<u128>,
        e: Mutex<Mutex<u8>>,
        g: Mutex<Vec<Mutex<u8>>>,
        u: usize,
    }

    // Level A guards `Data::a`.
    impl LockFor<A> for Data {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.a.lock().unwrap()
        }
    }

    // Level B guards `Data::b`.
    impl LockFor<B> for Data {
        type Data = u16;
        type Guard<'l> = MutexGuard<'l, u16>;
        fn lock(&self) -> Self::Guard<'_> {
            self.b.lock().unwrap()
        }
    }

    // Level C guards `Data::c`.
    impl LockFor<C> for Data {
        type Data = u64;
        type Guard<'l> = MutexGuard<'l, u64>;
        fn lock(&self) -> Self::Guard<'_> {
            self.c.lock().unwrap()
        }
    }

    // Level D guards `Data::d`, the only reader/writer lock in the fixture.
    impl RwLockFor<D> for Data {
        type Data = u128;
        type ReadGuard<'l> = RwLockReadGuard<'l, u128>;
        type WriteGuard<'l> = RwLockWriteGuard<'l, u128>;
        fn read_lock(&self) -> Self::ReadGuard<'_> {
            self.d.read().unwrap()
        }
        fn write_lock(&self) -> Self::WriteGuard<'_> {
            self.d.write().unwrap()
        }
    }

    // Level E guards the outer mutex of `Data::e`; the guarded value is itself a mutex.
    impl LockFor<E> for Data {
        type Data = Mutex<u8>;
        type Guard<'l> = MutexGuard<'l, Mutex<u8>>;
        fn lock(&self) -> Self::Guard<'_> {
            self.e.lock().unwrap()
        }
    }

    // Level F: a bare `Mutex<u8>`, such as the one stored inside `Data::e`.
    impl LockFor<F> for Mutex<u8> {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.lock().unwrap()
        }
    }

    // Level G guards the vector of mutexes in `Data::g`.
    impl LockFor<G> for Data {
        type Data = Vec<Mutex<u8>>;
        type Guard<'l> = MutexGuard<'l, Vec<Mutex<u8>>>;
        fn lock(&self) -> Self::Guard<'_> {
            self.g.lock().unwrap()
        }
    }

    // Level H: a bare `Mutex<u8>`, such as an element of `Data::g`.
    impl LockFor<H> for Mutex<u8> {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.lock().unwrap()
        }
    }

    #[derive(Debug)]
    #[allow(dead_code)]
    struct NotPresent;

    /// A level further down the tree (C) can be locked from A.
    #[test]
    fn lock_a_then_c() {
        let data = Data::default();

        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        let (_guard_a, mut after_a) = root.lock_and::<A, _>(&data);
        let (_guard_c, _after_c) = after_a.lock_and::<C, _>(&data);
        // Going back up the tree is rejected by the compiler:
        // let _b = _after_c.lock::<B, _>(&data);
    }

    /// Casting to level A (without locking it) still permits locking C.
    #[test]
    fn cast_a_then_c() {
        let data = Data::default();

        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        let mut as_a = root.cast_locked::<A>();
        let (_guard_c, _after_c) = as_a.lock_and::<C, _>(&data);
        // The root context is expected to be unusable here, so this should not compile:
        // let _b = root.lock::<B, _>(&data);
    }

    /// Holding a reference to an unlocked field does not get in the way of locking.
    #[test]
    fn unlocked_access_does_not_prevent_locking() {
        let data = Data { a: Mutex::new(15), u: 34, ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        let plain = &data.u;

        // `plain` is still alive here, yet the locked state can be reached.
        let guard_a = root.lock::<A, _>(&data);

        assert_eq!(*plain, 34);
        assert_eq!(*guard_a, 15);
    }

    /// A lock stored inside another lock's data can be taken through the outer guard.
    #[test]
    fn nested_locks() {
        let data = Data { e: Mutex::new(Mutex::new(1)), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        let (outer, mut after_e) = root.lock_and::<E, _>(&data);
        let inner = after_e.lock::<F, _>(&*outer);
        assert_eq!(*inner, 1);
    }

    /// A value written through the write lock is visible through the read lock.
    #[test]
    fn rw_lock() {
        let data = Data { d: RwLock::new(1), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        // The temporary write guard is released at the end of this statement.
        *root.write_lock::<D, _>(&data) = 10;
        let value = root.read_lock::<D, _>(&data);
        assert_eq!(*value, 10);
    }

    /// Elements of a locked collection can themselves be locked at a later level.
    #[test]
    fn collections() {
        let data = Data { g: Mutex::new(vec![Mutex::new(0), Mutex::new(1)]), ..Data::default() };

        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        let (items, mut after_g) = root.lock_and::<G, _>(&data);
        let second = after_g.lock::<H, _>(&items[1]);
        assert_eq!(*second, 1);
    }

    /// Two objects can be locked at the same level, with guards returned in argument order.
    #[test]
    fn lock_same_level() {
        let first = Data { a: Mutex::new(5), b: Mutex::new(15), ..Data::default() };
        let second = Data { a: Mutex::new(10), b: Mutex::new(20), ..Data::default() };
        // SAFETY: this test has not acquired any lock yet.
        let mut root = unsafe { Unlocked::new() };
        {
            let (a_first, a_second, mut after_a) = root.lock_both_and::<A, _>(&first, &second);
            assert_eq!(*a_first, 5);
            assert_eq!(*a_second, 10);
            let (b_first, b_second) = after_a.lock_both::<B, _>(&first, &second);
            assert_eq!(*b_first, 15);
            assert_eq!(*b_second, 20);
        }
        {
            // Same pair, arguments swapped: the guards follow the argument order.
            let (a_second, a_first) = root.lock_both::<A, _>(&second, &first);
            assert_eq!(*a_first, 5);
            assert_eq!(*a_second, 10);
        }
    }
}