starnix_core/vfs/rw_queue.rs

// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
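
//! An interruptible reader-writer lock for the starnix VFS.
//!
//! Unlike a plain `RwLock`, blocking on this lock goes through
//! `CurrentTask::block_until`, so a waiting task can be interrupted (in which
//! case acquisition fails with an `Errno`). Waiters are queued in FIFO order,
//! which keeps a steady stream of readers from starving writers.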

use crate::task::CurrentTask;
use starnix_uapi::errors::Errno;

use core::marker::PhantomData;

use starnix_sync::{InterruptibleEvent, LockBefore, Locked, Mutex};
use std::collections::VecDeque;
use std::sync::Arc;

use lock_api as _;

#[cfg(any(test, debug_assertions))]
use lock_api::RawRwLock;

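/// An interruptible, FIFO-fair reader-writer lock.
///
/// The type parameter `L` is the lock level used for lock-ordering checks
/// (see `LockBefore` in `starnix_sync`).
///
/// A minimal usage sketch (`MyLevel` is a hypothetical lock level; `locked`
/// and `current_task` come from the caller's context):
///
/// ```ignore
/// let queue = RwQueue::<MyLevel>::default();
/// let guard = queue.read(locked, current_task)?; // fails if interrupted
/// // ... read-side critical section ...
/// drop(guard);
/// ```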
#[derive(Debug)]
pub struct RwQueue<L> {
    inner: Mutex<RwQueueInner>,
    _phantom: PhantomData<L>,

    // Used to inform our deadlock detector about the waiters in the queue.
    #[cfg(any(test, debug_assertions))]
    tracer: tracer::MutexTracer,
}

impl<L> RwQueue<L> {
    // Acquires a read lock without checking lock ordering.
    // TODO(https://fxbug.dev/333540469): This should be part of the
    // implementation of an OrderedRwLock. However, this requires
    // OrderedRwLock to support a `read()` method that takes a context (in
    // this case, `CurrentTask`).
    fn read_internal(&self, current_task: &CurrentTask) -> Result<(), Errno> {
        #[cfg(any(test, debug_assertions))]
        self.tracer.lock_shared();

        let mut inner = self.inner.lock();

        if !inner.try_read() {
            let event = InterruptibleEvent::new();
            let guard = event.begin_wait();

            inner.waiters.push_back(Waiter::Reader(event.clone()));

            std::mem::drop(inner);

            current_task.block_until(guard, zx::MonotonicInstant::INFINITE).map_err(|e| {
                self.inner.lock().remove_waiter(&event);
                e
            })?;
        }
        Ok(())
    }

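    /// Acquires a read lock, blocking the current task until it is available.
    /// The wait is interruptible: on interruption the waiter is removed from
    /// the queue and the error is propagated. On success, returns the guard
    /// together with a `Locked` context narrowed to level `L`.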
    pub fn read_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<(RwQueueReadGuard<'a, L>, &'a mut Locked<L>), Errno>
    where
        P: LockBefore<L>,
    {
        self.read_internal(current_task)?;

        let new_locked = locked.cast_locked::<L>();

        Ok((RwQueueReadGuard { queue: self }, new_locked))
    }

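    /// Acquires the write lock, blocking the current task until no readers or
    /// writers hold the lock. Because waiters are served in FIFO order, a
    /// queued writer also blocks readers that arrive after it.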
    pub fn write_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<(RwQueueWriteGuard<'a, L>, &'a mut Locked<L>), Errno>
    where
        P: LockBefore<L>,
    {
        #[cfg(any(test, debug_assertions))]
        self.tracer.lock_exclusive();

        let mut inner = self.inner.lock();

        if !inner.try_write() {
            let event = InterruptibleEvent::new();
            let guard = event.begin_wait();

            inner.waiters.push_back(Waiter::Writer(event.clone()));

            std::mem::drop(inner);

            current_task.block_until(guard, zx::MonotonicInstant::INFINITE).map_err(|e| {
                self.inner.lock().remove_waiter(&event);
                e
            })?;
        }

        let new_locked = locked.cast_locked::<L>();
        Ok((RwQueueWriteGuard { queue: self }, new_locked))
    }

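    /// Like `read_and`, but discards the narrowed `Locked` context.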
    pub fn read<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<RwQueueReadGuard<'a, L>, Errno>
    where
        P: LockBefore<L>,
    {
        self.read_and(locked, current_task).map(|(g, _)| g)
    }

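    /// Like `write_and`, but discards the narrowed `Locked` context.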
    pub fn write<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<RwQueueWriteGuard<'a, L>, Errno>
    where
        P: LockBefore<L>,
    {
        self.write_and(locked, current_task).map(|(g, _)| g)
    }

    /// Used to establish lock ordering.
    ///
    /// Must only be called when the lock can be acquired immediately: the
    /// read lock is taken without blocking, and the acquisition is asserted
    /// to succeed.
    #[cfg(any(test, debug_assertions))]
    pub fn read_for_lock_ordering<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (RwQueueReadGuard<'a, L>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        #[cfg(any(test, debug_assertions))]
        self.tracer.lock_shared();

        assert!(self.inner.lock().try_read(), "Cannot fail to acquire a read for lock ordering.");
        let new_locked = locked.cast_locked::<L>();

        (RwQueueReadGuard { queue: self }, new_locked)
    }

    fn unlock_read(&self) {
        self.inner.lock().unlock_read();

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        #[cfg(any(test, debug_assertions))]
        unsafe {
            self.tracer.unlock_shared();
        }
    }

    fn unlock_write(&self) {
        self.inner.lock().unlock_write();

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        #[cfg(any(test, debug_assertions))]
        unsafe {
            self.tracer.unlock_exclusive();
        }
    }
}

impl<L> Default for RwQueue<L> {
    fn default() -> Self {
        Self {
            inner: Default::default(),
            #[cfg(any(test, debug_assertions))]
            tracer: Default::default(),
            _phantom: Default::default(),
        }
    }
}

/// The queue is ready for any operation.
const READY: usize = 0;

/// The queue has exactly one writer.
const WRITER: usize = 0b01;

/// Each reader in the queue increments the state by this amount.
const READER: usize = 0b10;
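
// As an illustration of the encoding: a state of `0b110` (three times
// `READER`) means three readers currently hold the lock, while `0b01` means a
// single writer holds it. The two cases are mutually exclusive, which
// `debug_assert_consistent` checks below.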

/// A writer is currently running.
fn has_writer(state: usize) -> bool {
    state & WRITER != 0
}

/// At least one reader is currently running.
fn has_reader(state: usize) -> bool {
    state >= READER
}

fn debug_assert_consistent(state: usize) {
    debug_assert!(!has_writer(state) || !has_reader(state));
}

#[derive(Debug, Clone)]
enum Waiter {
    Reader(Arc<InterruptibleEvent>),
    Writer(Arc<InterruptibleEvent>),
}

#[derive(Debug, Default)]
struct RwQueueInner {
    /// What operations are currently ongoing.
    ///
    /// See READY, READER, WRITER above for what these bits mean.
    state: usize,

    /// The operations that are waiting for the ongoing operations to complete.
    waiters: VecDeque<Waiter>,
}

impl RwQueueInner {
    fn has_waiters(&self) -> bool {
        !self.waiters.is_empty()
    }

    /// Attempts to take a read lock. Fails if a writer holds the lock or if
    /// anyone is already waiting, so that queued waiters keep their FIFO
    /// order.
    fn try_read(&mut self) -> bool {
        debug_assert_consistent(self.state);
        if !has_writer(self.state) && !self.has_waiters() {
            if let Some(new_state) = self.state.checked_add(READER) {
                self.state = new_state;
                return true;
            }
        }
        false
    }

    /// Attempts to take the write lock. Succeeds only if the lock is
    /// completely idle and no one is waiting.
    fn try_write(&mut self) -> bool {
        debug_assert_consistent(self.state);
        if self.state == READY && !self.has_waiters() {
            self.state += WRITER;
            true
        } else {
            false
        }
    }

    fn unlock_read(&mut self) {
        debug_assert!(has_reader(self.state) && !has_writer(self.state));
        self.state -= READER;

        if !has_reader(self.state) && self.has_waiters() {
            self.notify_next();
        }
    }

    fn unlock_write(&mut self) {
        debug_assert!(has_writer(self.state) && !has_reader(self.state));
        self.state -= WRITER;

        if self.has_waiters() {
            self.notify_next();
        }
    }

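    /// Wakes waiters from the front of the queue. A run of consecutive
    /// readers is admitted together, stopping at the first queued writer; a
    /// writer is admitted only once nothing else holds the lock.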
    fn notify_next(&mut self) {
        while let Some(waiter) = self.waiters.front() {
            match waiter {
                Waiter::Reader(reader) => {
                    if has_writer(self.state) {
                        return;
                    }
                    // We need to use `checked_add` to ensure we do not
                    // overflow the number of readers. If that happens, we just
                    // need to wait for the enormous number of readers to finish.
                    let Some(new_state) = self.state.checked_add(READER) else {
                        return;
                    };
                    self.state = new_state;
                    reader.notify();
                }
                Waiter::Writer(writer) => {
                    if has_reader(self.state) || has_writer(self.state) {
                        return;
                    }
                    // We can never overflow writers because we only let one
                    // through at a time.
                    self.state += WRITER;
                    writer.notify();
                }
            }
            self.waiters.pop_front();
        }
        debug_assert_consistent(self.state);
    }

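    /// Removes the waiter registered for `event`, identified by pointer
    /// equality. Used to unregister a waiter whose wait was interrupted.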
    fn remove_waiter(&mut self, event: &Arc<InterruptibleEvent>) {
        self.waiters.retain(|waiter| {
            let (Waiter::Reader(other) | Waiter::Writer(other)) = waiter;
            !Arc::ptr_eq(event, other)
        });
    }
}

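/// Releases the read lock (and wakes eligible waiters) when dropped.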
pub struct RwQueueReadGuard<'a, L> {
    queue: &'a RwQueue<L>,
}

impl<'a, L> Drop for RwQueueReadGuard<'a, L> {
    fn drop(&mut self) {
        self.queue.unlock_read();
    }
}

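/// Releases the write lock (and wakes eligible waiters) when dropped.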
pub struct RwQueueWriteGuard<'a, L> {
    queue: &'a RwQueue<L>,
}

impl<'a, L> Drop for RwQueueWriteGuard<'a, L> {
    fn drop(&mut self) {
        self.queue.unlock_write();
    }
}

#[cfg(any(test, debug_assertions))]
mod tracer {

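    // A no-op `RawRwLock`. It is never used as an actual lock: it exists only
    // so `TracingWrapper` has something to wrap while it records lock
    // ordering for the deadlock detector.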
    #[derive(Debug, Default)]
    pub struct FakeRwLock {}

    #[allow(
        clippy::undocumented_unsafe_blocks,
        reason = "Force documented unsafe blocks in Starnix"
    )]
    unsafe impl lock_api::RawRwLock for FakeRwLock {
        const INIT: Self = Self {};

        type GuardMarker = lock_api::GuardNoSend;

        fn lock_shared(&self) {}
        fn try_lock_shared(&self) -> bool {
            false
        }
        unsafe fn unlock_shared(&self) {}

        fn lock_exclusive(&self) {}
        fn try_lock_exclusive(&self) -> bool {
            false
        }
        unsafe fn unlock_exclusive(&self) {}

        fn is_locked(&self) -> bool {
            false
        }
    }

    // We should replace this type with tracing_mutex::MutexId once that type is public.
    pub type MutexTracer = tracing_mutex::lockapi::TracingWrapper<FakeRwLock>;
}
358
359// We use tracing_mutex in tests and debug assertions, but we don't want to pull it in for
360// production.
361#[cfg(not(any(test, debug_assertions)))]
362use tracing_mutex as _;
363
#[cfg(test)]
mod test {
    use super::*;
    use crate::task::Kernel;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::*;
    use futures::executor::block_on;
    use futures::future::join_all;
    use starnix_sync::{Unlocked, lock_ordering};
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicUsize, Ordering};

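    // An interrupted waiter must be removable from the middle of the queue
    // without disturbing the waiters around it.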
    #[::fuchsia::test]
    fn test_remove_from_queue() {
        let mut inner = RwQueueInner::default();
        let event1 = InterruptibleEvent::new();
        let event2 = InterruptibleEvent::new();
        let event3 = InterruptibleEvent::new();
        inner.waiters.push_back(Waiter::Writer(event1.clone()));
        inner.waiters.push_back(Waiter::Writer(event2.clone()));
        inner.waiters.push_back(Waiter::Writer(event3.clone()));

        inner.remove_waiter(&event2);

        let waiter = inner.waiters.pop_front().expect("should have a waiter");
        let Waiter::Writer(event) = waiter else {
            unreachable!();
        };
        assert!(Arc::ptr_eq(&event1, &event));

        let waiter = inner.waiters.pop_front().expect("should have a waiter");
        let Waiter::Writer(event) = waiter else {
            unreachable!();
        };
        assert!(Arc::ptr_eq(&event3, &event));

        assert!(inner.waiters.is_empty());
    }

    #[::fuchsia::test]
    async fn test_write_and_read() {
        lock_ordering! {
            Unlocked => TestLevel
        }

        spawn_kernel_and_run(async |locked, current_task| {
            let queue = RwQueue::<TestLevel>::default();
            let read_guard1 = queue.read(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(read_guard1);

            let write_guard = queue.write(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(write_guard);

            let read_guard2 = queue.read(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(read_guard2);
        })
        .await;
    }

    #[::fuchsia::test]
    async fn test_read_in_parallel() {
        spawn_kernel_and_run(async |_, current_task| {
            let kernel = current_task.kernel();
            lock_ordering! {
                Unlocked => TestLevel
            }
            struct Info {
                barrier: Barrier,
                queue: RwQueue<TestLevel>,
            }

            // Each thread blocks at the barrier while holding a read lock, so
            // the test can only complete if both readers hold the lock
            // concurrently.
            let info =
                Arc::new(Info { barrier: Barrier::new(2), queue: RwQueue::<TestLevel>::default() });

            let info1 = Arc::clone(&info);
            let closure1 = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let guard =
                    info1.queue.read(locked, current_task).expect("shouldn't be interrupted");
                info1.barrier.wait();
                std::mem::drop(guard);
            };
            let (thread1, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure1).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            let info2 = Arc::clone(&info);
            let closure2 = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let guard =
                    info2.queue.read(locked, current_task).expect("shouldn't be interrupted");
                info2.barrier.wait();
                std::mem::drop(guard);
            };
            let (thread2, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure2).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            block_on(async {
                thread1.await.expect("failed to join thread");
                thread2.await.expect("failed to join thread");
            });
        })
        .await;
    }

    lock_ordering! {
        Unlocked => A
    }
    struct State {
        queue: RwQueue<A>,
        gate: Barrier,
        writer_count: AtomicUsize,
        reader_count: AtomicUsize,
    }

    impl State {
        fn new(n: usize) -> State {
            State {
                queue: Default::default(),
                gate: Barrier::new(n),
                writer_count: Default::default(),
                reader_count: Default::default(),
            }
        }

        fn spawn_writer(
            state: Arc<Self>,
            kernel: Arc<Kernel>,
            count: usize,
        ) -> Pin<Box<dyn Future<Output = Result<(), Errno>> + Send>> {
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                state.gate.wait();
                for _ in 0..count {
                    let guard =
                        state.queue.write(locked, current_task).expect("shouldn't be interrupted");
                    let writer_count = state.writer_count.fetch_add(1, Ordering::Acquire) + 1;
                    let reader_count = state.reader_count.load(Ordering::Acquire);
                    state.writer_count.fetch_sub(1, Ordering::Release);
                    std::mem::drop(guard);
                    assert_eq!(writer_count, 1, "More than one writer held the lock at once.");
                    assert_eq!(
                        reader_count, 0,
                        "A reader and writer held the lock at the same time."
                    );
                }
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);
            Box::pin(result)
        }

        fn spawn_reader(
            state: Arc<Self>,
            kernel: Arc<Kernel>,
            count: usize,
        ) -> Pin<Box<dyn Future<Output = Result<(), Errno>> + Send>> {
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                state.gate.wait();
                for _ in 0..count {
                    let guard =
                        state.queue.read(locked, current_task).expect("shouldn't be interrupted");
                    let reader_count = state.reader_count.fetch_add(1, Ordering::Acquire) + 1;
                    let writer_count = state.writer_count.load(Ordering::Acquire);
                    state.reader_count.fetch_sub(1, Ordering::Release);
                    std::mem::drop(guard);
                    assert_eq!(
                        writer_count, 0,
                        "A reader and writer held the lock at the same time."
                    );
                    assert!(reader_count > 0, "A reader held the lock without being counted.");
                }
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);
            Box::pin(result)
        }
    }

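    // Stress test: reader and writer threads hammer the lock concurrently;
    // the atomic counters catch any overlap between two writers or between a
    // reader and a writer.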
    #[::fuchsia::test]
    async fn test_thundering_reads_and_writes() {
        spawn_kernel_and_run(async |_, current_task| {
            let kernel = current_task.kernel();
            const THREAD_PAIRS: usize = 10;

            let state = Arc::new(State::new(THREAD_PAIRS * 2));
            let mut threads = vec![];
            for _ in 0..THREAD_PAIRS {
                threads.push(State::spawn_writer(Arc::clone(&state), kernel.clone(), 100));
                threads.push(State::spawn_reader(Arc::clone(&state), kernel.clone(), 100));
            }

            block_on(join_all(threads)).into_iter().for_each(|r| r.expect("failed to join thread"));
        })
        .await;
    }
}