1use {fuchsia_sync as _, lock_api as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use lock_api::RawMutex;
11use std::{any, fmt};
12
13pub use fuchsia_sync::{
14 MappedMutexGuard, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
15};
16
/// A lock guard whose lock can be temporarily released while an async closure runs.
///
/// The trait is declared through `async_trait(?Send)`, so the future returned by
/// `unlocked_async` is not required to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait AsyncUnlockable {
    /// Runs `f` with the lock held by `s` released, and returns `f`'s output.
    ///
    /// `s` is taken by `&mut` so the guard cannot be used to reach the protected data while
    /// the lock is released. The implementation for `MutexGuard` in this file re-acquires the
    /// mutex before the call completes.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U;
}
29
/// Lets a held [`MutexGuard`] drop its lock across an `.await` and take it back afterwards.
#[async_trait::async_trait(?Send)]
impl<'a, T> crate::AsyncUnlockable for MutexGuard<'a, T> {
    /// Unlocks the mutex behind `s`, awaits `f()`, then re-locks the mutex.
    ///
    /// The re-lock is registered with `scopeguard::defer!`. It therefore also runs if `f`
    /// panics, or if this future is dropped while suspended in `f().await`. In every case the
    /// guard holds the lock again before its own `Drop` releases it. Note that the re-lock
    /// (`RawMutex::lock`) blocks until the mutex is available.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U,
    {
        // SAFETY: `s` is a live guard, so the current context holds this mutex. That is the
        // precondition of `RawMutex::unlock`. `raw()` is unsafe only because it permits exactly
        // this unlock behind the guard's back; the deferred re-lock below restores the
        // invariant the guard relies on.
        unsafe {
            Self::mutex(s).raw().unlock();
        }
        // Re-acquire the mutex when this scope exits (normal return, unwind, or future drop).
        scopeguard::defer!(
            // SAFETY: the mutex was released above; re-locking it once here hands ownership
            // back to the guard `s` before it is dropped.
            unsafe { Self::mutex(s).raw().lock() }
        );
        f().await
    }
}
47
48pub fn ordered_lock<'a, T>(
52 m1: &'a Mutex<T>,
53 m2: &'a Mutex<T>,
54) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
55 let ptr1: *const Mutex<T> = m1;
56 let ptr2: *const Mutex<T> = m2;
57 if ptr1 < ptr2 {
58 let g1 = m1.lock();
59 let g2 = m2.lock();
60 (g1, g2)
61 } else {
62 let g2 = m2.lock();
63 let g1 = m1.lock();
64 (g1, g2)
65 }
66}
67
68pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
71 let mut indexed_mutexes = mutexes.iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
73
74 indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
76
77 let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
79
80 guards.sort_by_key(|(i, _)| *i);
82
83 guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
84}
85
/// A [`Mutex`] tagged at the type level with the lock level `L` it occupies.
///
/// The level exists only in the type (`PhantomData<L>`), so it costs nothing at runtime.
/// Acquiring the mutex through `OrderedMutex::lock` / `lock_and` requires a `Locked<P>`
/// token with `P: LockBefore<L>`. `L` itself must be ordered after `UninterruptibleLock`.
pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
    // The underlying mutex that owns the protected data.
    mutex: Mutex<T>,
    // Zero-sized marker carrying the lock level `L`.
    _phantom: PhantomData<L>,
}
93
94impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
95 fn default() -> Self {
96 Self { mutex: Default::default(), _phantom: Default::default() }
97 }
98}
99
100impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
103 }
104}
105
106impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
107 type Data = T;
108 type Guard<'a>
109 = MutexGuard<'a, T>
110 where
111 T: 'a,
112 L: 'a;
113 fn lock(&self) -> Self::Guard<'_> {
114 self.mutex.lock()
115 }
116}
117
118impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
119 pub const fn new(t: T) -> Self {
120 Self { mutex: Mutex::new(t), _phantom: PhantomData }
121 }
122
123 pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
124 where
125 P: LockBefore<L>,
126 {
127 locked.lock(self)
128 }
129
130 pub fn lock_and<'a, P>(
131 &'a self,
132 locked: &'a mut Locked<P>,
133 ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
134 where
135 P: LockBefore<L>,
136 {
137 locked.lock_and(self)
138 }
139}
140
141pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
144 locked: &'a mut Locked<P>,
145 m1: &'a OrderedMutex<T, L>,
146 m2: &'a OrderedMutex<T, L>,
147) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, &'a mut Locked<L>)
148where
149 P: LockBefore<L>,
150{
151 locked.lock_both_and(m1, m2)
152}
153
/// An [`RwLock`] tagged at the type level with the lock level `L` it occupies.
///
/// The level exists only in the type (`PhantomData<L>`). Reading or writing through
/// `OrderedRwLock::read` / `write` (and their `_and` variants) requires a `Locked<P>` token
/// with `P: LockBefore<L>`. `L` itself must be ordered after `UninterruptibleLock`.
pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
    // The underlying reader/writer lock that owns the protected data.
    rwlock: RwLock<T>,
    // Zero-sized marker carrying the lock level `L`.
    _phantom: PhantomData<L>,
}
161
162impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
163 fn default() -> Self {
164 Self { rwlock: Default::default(), _phantom: Default::default() }
165 }
166}
167
168impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
171 }
172}
173
174impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
175 type Data = T;
176 type ReadGuard<'a>
177 = RwLockReadGuard<'a, T>
178 where
179 T: 'a,
180 L: 'a;
181 type WriteGuard<'a>
182 = RwLockWriteGuard<'a, T>
183 where
184 T: 'a,
185 L: 'a;
186 fn read_lock(&self) -> Self::ReadGuard<'_> {
187 self.rwlock.read()
188 }
189 fn write_lock(&self) -> Self::WriteGuard<'_> {
190 self.rwlock.write()
191 }
192}
193
194impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
195 pub const fn new(t: T) -> Self {
196 Self { rwlock: RwLock::new(t), _phantom: PhantomData }
197 }
198
199 pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
200 where
201 P: LockBefore<L>,
202 {
203 locked.read_lock(self)
204 }
205
206 pub fn write<'a, P>(
207 &'a self,
208 locked: &'a mut Locked<P>,
209 ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
210 where
211 P: LockBefore<L>,
212 {
213 locked.write_lock(self)
214 }
215
216 pub fn read_and<'a, P>(
217 &'a self,
218 locked: &'a mut Locked<P>,
219 ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
220 where
221 P: LockBefore<L>,
222 {
223 locked.read_lock_and(self)
224 }
225
226 pub fn write_and<'a, P>(
227 &'a self,
228 locked: &'a mut Locked<P>,
229 ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
230 where
231 P: LockBefore<L>,
232 {
233 locked.write_lock_and(self)
234 }
235}
236
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    // `ordered_lock` must return guards in argument order regardless of which is locked first.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
    }

    // `ordered_lock_vec` must return guards in input-slice order. The mutexes are declared out
    // of numeric order, presumably so that address order and value order need not coincide.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let l1 = Mutex::new(1);
        let l0 = Mutex::new(0);
        let l2 = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&l0, &l1, &l2]);
            assert_eq!(*guards[0], 0);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 2);
        }
        {
            let guards = ordered_lock_vec(&[&l2, &l1, &l0]);
            assert_eq!(*guards[0], 2);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 0);
        }
    }

    // Test-only lock levels: two independent chains, Unlocked -> A -> B -> C and
    // Unlocked -> D -> E -> F. Each level is also declared to come after `UninterruptibleLock`,
    // as `OrderedMutex` / `OrderedRwLock` require.
    mod lock_levels {
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    // Holding level A, the returned `Locked<A>` token can acquire level C directly (skipping B).
    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        // Creating the root token is `unsafe`. Presumably its contract is that the caller holds
        // no ordered locks, which is true at the start of a test.
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(locked);
        let c_data = c.lock(&mut next_locked);

        assert_eq!(&*a_data, &15);
        assert_eq!(&*c_data, &45);
    }
    // Same as above for `OrderedRwLock`, mixing write-then-read and read-then-write across
    // levels D and F.
    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (d_data, mut next_locked) = d.write_and(locked);
            let f_data = f.read(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
        {
            let (d_data, mut next_locked) = d.read_and(locked);
            let f_data = f.write(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
    }

    // `lock_both` on two mutexes at the same level must return guards in argument order,
    // whichever order the arguments are passed in.
    #[test]
    fn test_lock_both() {
        let a1: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let a2: OrderedMutex<u8, A> = OrderedMutex::new(30);
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(locked, &a1, &a2);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
        {
            let (a2_data, a1_data, _) = lock_both(locked, &a2, &a1);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
    }
}