1use crate::task::CurrentTask;
6use starnix_uapi::errors::Errno;
7
8use core::marker::PhantomData;
9
10use starnix_sync::{InterruptibleEvent, LockBefore, Locked, Mutex};
11use std::collections::VecDeque;
12use std::sync::Arc;
13
14use lock_api as _;
15
16#[cfg(any(test, debug_assertions))]
17use lock_api::RawRwLock;
18
/// A reader/writer lock whose blocked tasks are queued and granted the lock in arrival order.
///
/// New acquisitions never jump ahead of tasks that are already queued. `L` is the lock level:
/// `read_and`/`write_and` require a `Locked<P>` with `P: LockBefore<L>` and hand back a
/// `Locked<L>` alongside the guard.
#[derive(Debug)]
pub struct RwQueue<L> {
    // Lock word plus the FIFO of blocked waiters.
    inner: Mutex<RwQueueInner>,
    _phantom: PhantomData<L>,

    // Debug/test-only wrapper that every acquisition and release is reported to, presumably so
    // `tracing_mutex` can flag lock-ordering cycles involving this queue.
    #[cfg(any(test, debug_assertions))]
    tracer: tracer::MutexTracer,
}
28
29impl<L> RwQueue<L> {
30 fn read_internal(&self, current_task: &CurrentTask) -> Result<(), Errno> {
35 #[cfg(any(test, debug_assertions))]
36 self.tracer.lock_shared();
37
38 let mut inner = self.inner.lock();
39
40 if !inner.try_read() {
41 let event = InterruptibleEvent::new();
42 let guard = event.begin_wait();
43
44 inner.waiters.push_back(Waiter::Reader(event.clone()));
45
46 std::mem::drop(inner);
47
48 current_task.block_until(guard, zx::MonotonicInstant::INFINITE).map_err(|e| {
49 self.inner.lock().remove_waiter(&event);
50 e
51 })?;
52 }
53 Ok(())
54 }
55
56 pub fn read_and<'a, P>(
57 &'a self,
58 locked: &'a mut Locked<P>,
59 current_task: &CurrentTask,
60 ) -> Result<(RwQueueReadGuard<'a, L>, &'a mut Locked<L>), Errno>
61 where
62 P: LockBefore<L>,
63 {
64 self.read_internal(current_task)?;
65
66 let new_locked = locked.cast_locked::<L>();
67
68 Ok((RwQueueReadGuard { queue: self }, new_locked))
69 }
70
71 pub fn write_and<'a, P>(
72 &'a self,
73 locked: &'a mut Locked<P>,
74 current_task: &CurrentTask,
75 ) -> Result<(RwQueueWriteGuard<'a, L>, &'a mut Locked<L>), Errno>
76 where
77 P: LockBefore<L>,
78 {
79 #[cfg(any(test, debug_assertions))]
80 self.tracer.lock_exclusive();
81
82 let mut inner = self.inner.lock();
83
84 if !inner.try_write() {
85 let event = InterruptibleEvent::new();
86 let guard = event.begin_wait();
87
88 inner.waiters.push_back(Waiter::Writer(event.clone()));
89
90 std::mem::drop(inner);
91
92 current_task.block_until(guard, zx::MonotonicInstant::INFINITE).map_err(|e| {
93 self.inner.lock().remove_waiter(&event);
94 e
95 })?;
96 }
97
98 let new_locked = locked.cast_locked::<L>();
99 Ok((RwQueueWriteGuard { queue: self }, new_locked))
100 }
101
102 pub fn read<'a, P>(
103 &'a self,
104 locked: &'a mut Locked<P>,
105 current_task: &CurrentTask,
106 ) -> Result<RwQueueReadGuard<'a, L>, Errno>
107 where
108 P: LockBefore<L>,
109 {
110 self.read_and(locked, current_task).map(|(g, _)| g)
111 }
112
113 pub fn write<'a, P>(
114 &'a self,
115 locked: &'a mut Locked<P>,
116 current_task: &CurrentTask,
117 ) -> Result<RwQueueWriteGuard<'_, L>, Errno>
118 where
119 P: LockBefore<L>,
120 {
121 self.write_and(locked, current_task).map(|(g, _)| g)
122 }
123
124 #[cfg(any(test, debug_assertions))]
126 pub fn read_for_lock_ordering<'a, P>(
127 &'a self,
128 locked: &'a mut Locked<P>,
129 ) -> (RwQueueReadGuard<'a, L>, &'a mut Locked<L>)
130 where
131 P: LockBefore<L>,
132 {
133 #[cfg(any(test, debug_assertions))]
134 self.tracer.lock_shared();
135
136 assert!(self.inner.lock().try_read(), "Cannot fail to acquire a read for lock ordering.");
137 let new_locked = locked.cast_locked::<L>();
138
139 (RwQueueReadGuard { queue: self }, new_locked)
140 }
141
142 fn unlock_read(&self) {
143 self.inner.lock().unlock_read();
144
145 #[allow(
146 clippy::undocumented_unsafe_blocks,
147 reason = "Force documented unsafe blocks in Starnix"
148 )]
149 #[cfg(any(test, debug_assertions))]
150 unsafe {
151 self.tracer.unlock_shared();
152 }
153 }
154
155 fn unlock_write(&self) {
156 self.inner.lock().unlock_write();
157
158 #[allow(
159 clippy::undocumented_unsafe_blocks,
160 reason = "Force documented unsafe blocks in Starnix"
161 )]
162 #[cfg(any(test, debug_assertions))]
163 unsafe {
164 self.tracer.unlock_exclusive();
165 }
166 }
167}
168
169impl<L> Default for RwQueue<L> {
170 fn default() -> Self {
171 Self {
172 inner: Default::default(),
173 #[cfg(any(test, debug_assertions))]
174 tracer: Default::default(),
175 _phantom: Default::default(),
176 }
177 }
178}
179
/// Lock-word value when nobody holds the lock.
const READY: usize = 0;

/// Lowest bit of the lock word; set while a writer holds the lock.
const WRITER: usize = 0b01;

/// Amount added to the lock word per reader; readers are counted in the bits above `WRITER`.
const READER: usize = 0b10;

/// Reports whether the writer bit is set in `state`.
fn has_writer(state: usize) -> bool {
    (state & WRITER) == WRITER
}

/// Reports whether `state` records at least one reader (any bit above `WRITER` is set).
fn has_reader(state: usize) -> bool {
    (state & !WRITER) != READY
}

/// Debug-checks that `state` never records a writer and readers at the same time.
fn debug_assert_consistent(state: usize) {
    debug_assert!(!has_writer(state) || !has_reader(state));
}
202
/// An entry in the wait queue: the event to notify once the lock has been granted, tagged with
/// the mode the blocked task asked for.
#[derive(Debug, Clone)]
enum Waiter {
    /// A task waiting for a shared (read) hold.
    Reader(Arc<InterruptibleEvent>),
    /// A task waiting for the exclusive (write) hold.
    Writer(Arc<InterruptibleEvent>),
}
208
/// Mutex-protected state of an `RwQueue`.
#[derive(Debug, Default)]
struct RwQueueInner {
    // Lock word: bit 0 (`WRITER`) is set while a writer holds the lock; the remaining bits count
    // readers in units of `READER`. `READY` (0, the `Default`) means unlocked.
    state: usize,

    // Blocked tasks in arrival order. Only entries at the front are ever granted the lock, and
    // they are popped when granted.
    waiters: VecDeque<Waiter>,
}
219
220impl RwQueueInner {
221 fn has_waiters(&self) -> bool {
222 !self.waiters.is_empty()
223 }
224
225 fn try_read(&mut self) -> bool {
226 debug_assert_consistent(self.state);
227 if !has_writer(self.state) && !self.has_waiters() {
228 if let Some(new_state) = self.state.checked_add(READER) {
229 self.state = new_state;
230 return true;
231 }
232 }
233 false
234 }
235
236 fn try_write(&mut self) -> bool {
237 debug_assert_consistent(self.state);
238 if self.state == READY && !self.has_waiters() {
239 self.state += WRITER;
240 true
241 } else {
242 false
243 }
244 }
245
246 fn unlock_read(&mut self) {
247 debug_assert!(has_reader(self.state) && !has_writer(self.state));
248 self.state -= READER;
249
250 if !has_reader(self.state) && self.has_waiters() {
251 self.notify_next();
252 }
253 }
254
255 fn unlock_write(&mut self) {
256 debug_assert!(has_writer(self.state) && !has_reader(self.state));
257 self.state -= WRITER;
258
259 if self.has_waiters() {
260 self.notify_next();
261 }
262 }
263
264 fn notify_next(&mut self) {
265 while let Some(waiter) = self.waiters.front() {
266 match waiter {
267 Waiter::Reader(reader) => {
268 if has_writer(self.state) {
269 return;
270 }
271 let Some(new_state) = self.state.checked_add(READER) else {
275 return;
276 };
277 self.state = new_state;
278 reader.notify();
279 }
280 Waiter::Writer(writer) => {
281 if has_reader(self.state) || has_writer(self.state) {
282 return;
283 }
284 self.state += WRITER;
287 writer.notify();
288 }
289 }
290 self.waiters.pop_front();
291 }
292 debug_assert_consistent(self.state);
293 }
294
295 fn remove_waiter(&mut self, event: &Arc<InterruptibleEvent>) {
296 self.waiters.retain(|waiter| {
297 let (Waiter::Reader(other) | Waiter::Writer(other)) = waiter;
298 !Arc::ptr_eq(event, other)
299 });
300 }
301}
302
303pub struct RwQueueReadGuard<'a, L> {
304 queue: &'a RwQueue<L>,
305}
306
307impl<'a, L> Drop for RwQueueReadGuard<'a, L> {
308 fn drop(&mut self) {
309 self.queue.unlock_read();
310 }
311}
312
313pub struct RwQueueWriteGuard<'a, L> {
314 queue: &'a RwQueue<L>,
315}
316
317impl<'a, L> Drop for RwQueueWriteGuard<'a, L> {
318 fn drop(&mut self) {
319 self.queue.unlock_write();
320 }
321}
322
/// Debug/test-only plumbing that lets `RwQueue` report its acquisitions to `tracing_mutex`.
#[cfg(any(test, debug_assertions))]
mod tracer {

    /// A `lock_api::RawRwLock` that performs no locking at all.
    ///
    /// `RwQueue` does its own synchronization through `RwQueueInner`; this type only gives
    /// `tracing_mutex::lockapi::TracingWrapper` something to wrap, so `RwQueue` can call the
    /// wrapper's `lock_*` / `unlock_*` methods purely for bookkeeping.
    #[derive(Debug, Default)]
    pub struct FakeRwLock {}

    #[allow(
        clippy::undocumented_unsafe_blocks,
        reason = "Force documented unsafe blocks in Starnix"
    )]
    unsafe impl lock_api::RawRwLock for FakeRwLock {
        const INIT: Self = Self {};

        type GuardMarker = lock_api::GuardNoSend;

        // Blocking acquisitions and releases are no-ops.
        fn lock_shared(&self) {}
        // NOTE(review): the `try_*` methods always report failure; `RwQueue` never calls them
        // on the tracer, so this appears to be a placeholder.
        fn try_lock_shared(&self) -> bool {
            false
        }
        unsafe fn unlock_shared(&self) {}

        fn lock_exclusive(&self) {}
        fn try_lock_exclusive(&self) -> bool {
            false
        }
        unsafe fn unlock_exclusive(&self) {}

        // Always reports "unlocked", since this type holds no state.
        fn is_locked(&self) -> bool {
            false
        }
    }

    /// Tracing wrapper stored in each `RwQueue` in debug/test builds.
    pub type MutexTracer = tracing_mutex::lockapi::TracingWrapper<FakeRwLock>;
}
358
359#[cfg(not(any(test, debug_assertions)))]
362use tracing_mutex as _;
363
#[cfg(test)]
mod test {
    use super::*;
    use crate::task::Kernel;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::*;
    use futures::executor::block_on;
    use futures::future::join_all;
    use starnix_sync::{Unlocked, lock_ordering};
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// `remove_waiter` drops only the matching entry and keeps the remaining ones in order.
    #[::fuchsia::test]
    fn test_remove_from_queue() {
        let mut inner = RwQueueInner::default();
        let event1 = InterruptibleEvent::new();
        let event2 = InterruptibleEvent::new();
        let event3 = InterruptibleEvent::new();
        inner.waiters.push_back(Waiter::Writer(event1.clone()));
        inner.waiters.push_back(Waiter::Writer(event2.clone()));
        inner.waiters.push_back(Waiter::Writer(event3.clone()));

        inner.remove_waiter(&event2);

        let waiter = inner.waiters.pop_front().expect("should have a waiter");
        let Waiter::Writer(event) = waiter else {
            unreachable!();
        };
        assert!(Arc::ptr_eq(&event1, &event));

        let waiter = inner.waiters.pop_front().expect("should have a waiter");
        let Waiter::Writer(event) = waiter else {
            unreachable!();
        };
        assert!(Arc::ptr_eq(&event3, &event));

        assert!(inner.waiters.is_empty());
    }

    /// Uncontended read, write, read sequence on a single task; none of these should block.
    #[::fuchsia::test]
    async fn test_write_and_read() {
        lock_ordering! {
            Unlocked => TestLevel
        }

        spawn_kernel_and_run(async |locked, current_task| {
            let queue = RwQueue::<TestLevel>::default();
            let read_guard1 = queue.read(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(read_guard1);

            let write_guard = queue.write(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(write_guard);

            let read_guard2 = queue.read(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(read_guard2);
        })
        .await;
    }

    /// Two readers must be able to hold the lock at the same time: each one waits on a
    /// two-party barrier while still holding its read guard, so the test only finishes if both
    /// reads were granted concurrently.
    #[::fuchsia::test]
    async fn test_read_in_parallel() {
        spawn_kernel_and_run(async |_, current_task| {
            let kernel = current_task.kernel();
            lock_ordering! {
                Unlocked => TestLevel
            }
            struct Info {
                barrier: Barrier,
                queue: RwQueue<TestLevel>,
            }

            let info =
                Arc::new(Info { barrier: Barrier::new(2), queue: RwQueue::<TestLevel>::default() });

            let info1 = Arc::clone(&info);
            let closure1 = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let guard =
                    info1.queue.read(locked, current_task).expect("shouldn't be interrupted");
                info1.barrier.wait();
                std::mem::drop(guard);
            };
            let (thread1, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure1).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            let info2 = Arc::clone(&info);
            let closure2 = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let guard =
                    info2.queue.read(locked, current_task).expect("shouldn't be interrupted");
                info2.barrier.wait();
                std::mem::drop(guard);
            };
            let (thread2, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure2).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            block_on(async {
                thread1.await.expect("failed to join thread");
                thread2.await.expect("failed to join thread");
            });
        })
        .await;
    }

    lock_ordering! {
        Unlocked => A
    }
    /// Shared state for the stress test. `gate` releases all threads at once; the two counters
    /// record how many writers/readers are currently inside their critical section.
    struct State {
        queue: RwQueue<A>,
        gate: Barrier,
        writer_count: AtomicUsize,
        reader_count: AtomicUsize,
    }

    impl State {
        /// Creates state whose start gate opens once `n` threads are waiting on it.
        fn new(n: usize) -> State {
            State {
                queue: Default::default(),
                gate: Barrier::new(n),
                writer_count: Default::default(),
                reader_count: Default::default(),
            }
        }

        /// Spawns a kernel thread that takes the write lock `count` times. While holding it, the
        /// thread snapshots both counters; after releasing it, the thread asserts that it was the
        /// only writer and that no reader was inside at the same time.
        fn spawn_writer(
            state: Arc<Self>,
            kernel: Arc<Kernel>,
            count: usize,
        ) -> Pin<Box<dyn Future<Output = Result<(), Errno>> + Send>> {
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                state.gate.wait();
                for _ in 0..count {
                    let guard =
                        state.queue.write(locked, current_task).expect("shouldn't be interrupted");
                    let writer_count = state.writer_count.fetch_add(1, Ordering::Acquire) + 1;
                    let reader_count = state.reader_count.load(Ordering::Acquire);
                    state.writer_count.fetch_sub(1, Ordering::Release);
                    std::mem::drop(guard);
                    assert_eq!(writer_count, 1, "More than one writer held the lock at once.");
                    assert_eq!(
                        reader_count, 0,
                        "A reader and writer held the lock at the same time."
                    );
                }
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);
            Box::pin(result)
        }

        /// Spawns a kernel thread that takes the read lock `count` times. While holding it, the
        /// thread snapshots both counters; after releasing it, the thread asserts that no writer
        /// was inside at the same time.
        fn spawn_reader(
            state: Arc<Self>,
            kernel: Arc<Kernel>,
            count: usize,
        ) -> Pin<Box<dyn Future<Output = Result<(), Errno>> + Send>> {
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                state.gate.wait();
                for _ in 0..count {
                    let guard =
                        state.queue.read(locked, current_task).expect("shouldn't be interrupted");
                    let reader_count = state.reader_count.fetch_add(1, Ordering::Acquire) + 1;
                    let writer_count = state.writer_count.load(Ordering::Acquire);
                    state.reader_count.fetch_sub(1, Ordering::Release);
                    std::mem::drop(guard);
                    assert_eq!(
                        writer_count, 0,
                        "A reader and writer held the lock at the same time."
                    );
                    assert!(reader_count > 0, "A reader held the lock without being counted.");
                }
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);
            Box::pin(result)
        }
    }

    /// Stress test: `THREAD_PAIRS` writers and `THREAD_PAIRS` readers, released together by
    /// the gate, each acquire the lock 100 times while checking mutual exclusion.
    #[::fuchsia::test]
    async fn test_thundering_reads_and_writes() {
        spawn_kernel_and_run(async |_, current_task| {
            let kernel = current_task.kernel();
            const THREAD_PAIRS: usize = 10;

            let state = Arc::new(State::new(THREAD_PAIRS * 2));
            let mut threads = vec![];
            for _ in 0..THREAD_PAIRS {
                threads.push(State::spawn_writer(Arc::clone(&state), kernel.clone(), 100));
                threads.push(State::spawn_reader(Arc::clone(&state), kernel.clone(), 100));
            }

            block_on(join_all(threads)).into_iter().for_each(|r| r.expect("failed to join thread"));
        })
        .await;
    }
}