1use {fuchsia_sync as _, lock_api as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use lock_api::RawMutex;
11use std::{any, fmt};
12
13pub use fuchsia_sync::{
14 MappedMutexGuard, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
15};
16
17#[async_trait::async_trait(?Send)]
21pub trait AsyncUnlockable {
22 async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
26 where
27 F: AsyncFnOnce() -> U;
28}
29
30#[async_trait::async_trait(?Send)]
31impl<'a, T> crate::AsyncUnlockable for MutexGuard<'a, T> {
32 async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
33 where
34 F: AsyncFnOnce() -> U,
35 {
36 unsafe {
38 Self::mutex(s).raw().unlock();
39 }
40 scopeguard::defer!(
41 unsafe { Self::mutex(s).raw().lock() }
43 );
44 f().await
45 }
46}
47
48pub fn ordered_lock<'a, T>(
52 m1: &'a Mutex<T>,
53 m2: &'a Mutex<T>,
54) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
55 let ptr1: *const Mutex<T> = m1;
56 let ptr2: *const Mutex<T> = m2;
57 if ptr1 < ptr2 {
58 let g1 = m1.lock();
59 let g2 = m2.lock();
60 (g1, g2)
61 } else {
62 let g2 = m2.lock();
63 let g1 = m1.lock();
64 (g1, g2)
65 }
66}
67
68pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
71 let mut indexed_mutexes =
73 mutexes.into_iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
74
75 indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
77
78 let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
80
81 guards.sort_by_key(|(i, _)| *i);
83
84 guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
85}
86
87pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
91 mutex: Mutex<T>,
92 _phantom: PhantomData<L>,
93}
94
95impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
96 fn default() -> Self {
97 Self { mutex: Default::default(), _phantom: Default::default() }
98 }
99}
100
101impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
104 }
105}
106
107impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
108 type Data = T;
109 type Guard<'a>
110 = MutexGuard<'a, T>
111 where
112 T: 'a,
113 L: 'a;
114 fn lock(&self) -> Self::Guard<'_> {
115 self.mutex.lock()
116 }
117}
118
119impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
120 pub const fn new(t: T) -> Self {
121 Self { mutex: Mutex::new(t), _phantom: PhantomData }
122 }
123
124 pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
125 where
126 P: LockBefore<L>,
127 {
128 locked.lock(self)
129 }
130
131 pub fn lock_and<'a, P>(
132 &'a self,
133 locked: &'a mut Locked<P>,
134 ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
135 where
136 P: LockBefore<L>,
137 {
138 locked.lock_and(self)
139 }
140}
141
142pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
145 locked: &'a mut Locked<P>,
146 m1: &'a OrderedMutex<T, L>,
147 m2: &'a OrderedMutex<T, L>,
148) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, &'a mut Locked<L>)
149where
150 P: LockBefore<L>,
151{
152 locked.lock_both_and(m1, m2)
153}
154
155pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
159 rwlock: RwLock<T>,
160 _phantom: PhantomData<L>,
161}
162
163impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
164 fn default() -> Self {
165 Self { rwlock: Default::default(), _phantom: Default::default() }
166 }
167}
168
169impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
172 }
173}
174
175impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
176 type Data = T;
177 type ReadGuard<'a>
178 = RwLockReadGuard<'a, T>
179 where
180 T: 'a,
181 L: 'a;
182 type WriteGuard<'a>
183 = RwLockWriteGuard<'a, T>
184 where
185 T: 'a,
186 L: 'a;
187 fn read_lock(&self) -> Self::ReadGuard<'_> {
188 self.rwlock.read()
189 }
190 fn write_lock(&self) -> Self::WriteGuard<'_> {
191 self.rwlock.write()
192 }
193}
194
195impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
196 pub const fn new(t: T) -> Self {
197 Self { rwlock: RwLock::new(t), _phantom: PhantomData }
198 }
199
200 pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
201 where
202 P: LockBefore<L>,
203 {
204 locked.read_lock(self)
205 }
206
207 pub fn write<'a, P>(
208 &'a self,
209 locked: &'a mut Locked<P>,
210 ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
211 where
212 P: LockBefore<L>,
213 {
214 locked.write_lock(self)
215 }
216
217 pub fn read_and<'a, P>(
218 &'a self,
219 locked: &'a mut Locked<P>,
220 ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
221 where
222 P: LockBefore<L>,
223 {
224 locked.read_lock_and(self)
225 }
226
227 pub fn write_and<'a, P>(
228 &'a self,
229 locked: &'a mut Locked<P>,
230 ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
231 where
232 P: LockBefore<L>,
233 {
234 locked.write_lock_and(self)
235 }
236}
237
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    /// `ordered_lock` hands back the guards in argument order, whichever way round the two
    /// mutexes are passed.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let one = Mutex::new(1);
        let two = Mutex::new(2);

        {
            let (guard_one, guard_two) = ordered_lock(&one, &two);
            assert_eq!(*guard_one, 1);
            assert_eq!(*guard_two, 2);
        }
        {
            let (guard_two, guard_one) = ordered_lock(&two, &one);
            assert_eq!(*guard_one, 1);
            assert_eq!(*guard_two, 2);
        }
    }

    /// `ordered_lock_vec` returns its guards in the same order as the input slice.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        // Declared out of numeric order on purpose: declaration order and value order differ.
        let one = Mutex::new(1);
        let zero = Mutex::new(0);
        let two = Mutex::new(2);

        {
            let ascending = ordered_lock_vec(&[&zero, &one, &two]);
            assert_eq!(*ascending[0], 0);
            assert_eq!(*ascending[1], 1);
            assert_eq!(*ascending[2], 2);
        }
        {
            let descending = ordered_lock_vec(&[&two, &one, &zero]);
            assert_eq!(*descending[0], 2);
            assert_eq!(*descending[1], 1);
            assert_eq!(*descending[2], 0);
        }
    }

    /// Lock levels used by the tests below: two independent chains rooted at `Unlocked`.
    mod lock_levels {
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // `OrderedMutex` and `OrderedRwLock` require this bound on their level parameter.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    /// Locks `A` and then `C`, skipping the intermediate level `B`.
    #[test]
    fn test_ordered_mutex() {
        let a = OrderedMutex::<u8, A>::new(15);
        let _b = OrderedMutex::<u16, B>::new(30);
        let c = OrderedMutex::<u32, C>::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };

        let (a_guard, mut after_a) = a.lock_and(locked);
        let c_guard = c.lock(&mut after_a);

        assert_eq!(*a_guard, 15);
        assert_eq!(*c_guard, 45);
    }

    /// Takes `D` then `F`, once as write-then-read and once as read-then-write.
    #[test]
    fn test_ordered_rwlock() {
        let d = OrderedRwLock::<u8, D>::new(15);
        let _e = OrderedRwLock::<u16, E>::new(30);
        let f = OrderedRwLock::<u32, F>::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (d_guard, mut after_d) = d.write_and(locked);
            let f_guard = f.read(&mut after_d);

            assert_eq!(*d_guard, 15);
            assert_eq!(*f_guard, 45);
        }
        {
            let (d_guard, mut after_d) = d.read_and(locked);
            let f_guard = f.write(&mut after_d);

            assert_eq!(*d_guard, 15);
            assert_eq!(*f_guard, 45);
        }
    }

    /// `lock_both` returns the guards in argument order for both argument orders.
    #[test]
    fn test_lock_both() {
        let first = OrderedMutex::<u8, A>::new(15);
        let second = OrderedMutex::<u8, A>::new(30);
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (first_guard, second_guard, _) = lock_both(locked, &first, &second);
            assert_eq!(*first_guard, 15);
            assert_eq!(*second_guard, 30);
        }
        {
            let (second_guard, first_guard, _) = lock_both(locked, &second, &first);
            assert_eq!(*first_guard, 15);
            assert_eq!(*second_guard, 30);
        }
    }
}