_core_rustc_static/
waiter_list.rs

// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use std::cell::Cell;
6use std::sync::{Condvar, MutexGuard};
7use std::time::Duration;
8
9/// A collection of waiting threads, each identified by the address being waited for.
10pub struct WaiterList {
11    /// The first node in the linked list (null if the list is empty).
12    head: *const Node,
13}
14
15// SAFETY: Nodes' `next` pointers are managed by the WaiterList methods and nodes' contents are safe
16// to access from any thread.
17unsafe impl Send for WaiterList {}
18
/// One waiting thread's entry in a `WaiterList`.
struct Node {
    /// The address this waiter wants to be notified about.
    waited_address: u64,

    /// Condition variable used to wake up the waiting thread.
    condvar: Condvar,

    /// Link to the successor in the intrusive singly-linked list.
    ///
    /// - `None`: this node is not currently part of a list.
    /// - `Some(ptr)`: this node is part of a list and `ptr` is its successor; `ptr` is null when
    ///   this node is the last one in the list.
    next: Cell<Option<*const Node>>,
}
33
34impl Default for WaiterList {
35    fn default() -> Self {
36        Self { head: std::ptr::null() }
37    }
38}
39
40impl WaiterList {
41    /// Waits until a notification for the given address is received.
42    ///
43    /// The caller must provide the MutexGuard that protects the WaiterList and a function to obtain
44    /// it starting from the mutex's inner type.
45    ///
46    /// Note: The mutex may be released and re-acquired multiple times while waiting.
47    ///
48    /// If the requested timeout is exceeded, this function will panic.
49    pub fn wait<'a, T>(
50        mut guard: MutexGuard<'a, T>,
51        get_waiter_list: impl Fn(&mut T) -> &mut WaiterList,
52        address: u64,
53        panic_after_timeout: Duration,
54    ) -> MutexGuard<'a, T> {
55        let node = std::pin::pin!(Node {
56            waited_address: address,
57            condvar: Condvar::new(),
58            next: Cell::new(None)
59        });
60
61        // Insert it at the head of the list.
62        let old_head = std::mem::replace(&mut get_waiter_list(&mut guard).head, &*node);
63        node.next.set(Some(old_head));
64
65        // When the address is notified, the node will be removed from the list and its condvar
66        // notified.
67        let (guard, timeout_result) = node
68            .condvar
69            .wait_timeout_while(guard, panic_after_timeout, |_| node.next.get().is_some())
70            .unwrap();
71        if timeout_result.timed_out() {
72            panic!("WaiterList::wait timed out while waiting for address {:#x}", address);
73        }
74
75        guard
76    }
77
78    /// Notifies the first node that is waiting on the given address.
79    pub fn notify_one(&mut self, address: u64) {
80        let mut prev: Option<&Node> = None;
81        let mut it: *const Node = self.head;
82
83        // SAFETY: If `it` is not null, the object it points to is alive, because it's part of this
84        // list.
85        while let Some(node) = unsafe { it.as_ref() } {
86            if node.waited_address == address {
87                // Remove the node.
88                let next = node.next.take().expect("node must be in the list");
89                if let Some(prev) = prev {
90                    prev.next.set(Some(next));
91                } else {
92                    self.head = next;
93                }
94
95                // Notify the waiting thread.
96                node.condvar.notify_one();
97
98                return;
99            } else {
100                // Advance to the next node.
101                it = node.next.get().expect("node must be in the list");
102                prev = Some(node);
103            }
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// A waiter returns once another thread notifies the address it is waiting for.
    #[test]
    fn test_notify_waited_address() {
        let waiter_list = Arc::new(Mutex::new(WaiterList::default()));
        let guard = waiter_list.lock().unwrap();

        // The notifier blocks on the mutex (held by this thread) until `wait` releases it.
        let notifier_thread = {
            let waiter_list = waiter_list.clone();
            std::thread::spawn(move || {
                let mut guard = waiter_list.lock().unwrap();
                guard.notify_one(0x1234);
            })
        };

        // Wait for an address that is notified by the above thread as soon as the `wait` function
        // internally releases the lock.
        let guard =
            WaiterList::wait(guard, |waiter_list| waiter_list, 0x1234, Duration::from_secs(10));
        drop(guard);

        // Join the notifier so that a panic inside it fails this test instead of being lost.
        notifier_thread.join().expect("notifier thread panicked");
    }

    /// Notifying an address on an empty list is a no-op.
    #[test]
    fn test_notify_empty_list() {
        let waiter_list = Mutex::new(WaiterList::default());
        let mut guard = waiter_list.lock().unwrap();
        guard.notify_one(0x1234);
    }

    /// Waiting for an address that nobody notifies panics once the timeout expires.
    #[test]
    #[should_panic(expected = "timed out while waiting for address 0x1234")]
    fn test_wait_timeout() {
        let waiter_list = Mutex::new(WaiterList::default());
        let guard = waiter_list.lock().unwrap();

        // Wait for an address that is never notified. This will panic due to timeout.
        let guard =
            WaiterList::wait(guard, |waiter_list| waiter_list, 0x1234, Duration::from_millis(1));
        drop(guard);
    }
}