1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
45use std::cell::Cell;
6use std::sync::{Condvar, MutexGuard};
7use std::time::Duration;
/// A collection of waiting threads, each identified by the address being waited for.
///
/// The list is intrusive: its nodes live on the stacks of the threads blocked in
/// [`WaiterList::wait`]. Nodes are inserted by `wait` while it holds the `MutexGuard` that
/// protects this list, and removed either by [`WaiterList::notify_one`] or by `wait` itself
/// before it returns or panics.
pub struct WaiterList {
    /// The first node in the linked list (null if the list is empty).
    head: *const Node,
}

// SAFETY: A WaiterList only stores raw pointers to `Node`s. Nodes are linked by `wait` while it
// holds the list's `MutexGuard`; the waiting thread only inspects its node while the mutex is held
// (inside the `wait_timeout_while` condition and after it returns); every other access goes
// through `&mut self`. All accesses to the nodes (including their `Cell` fields) are therefore
// serialized by the mutex that protects the list, and `Condvar`/`u64` are themselves thread-safe.
unsafe impl Send for WaiterList {}

/// A node of a [`WaiterList`], owned by (and stored on the stack of) the waiting thread.
struct Node {
    /// The address being waited for.
    waited_address: u64,

    /// A condition variable to wake up the waiting thread.
    condvar: Condvar,

    /// Nodes are stored in an intrusive singly-linked list.
    ///
    /// - None means that this node is not currently part of a list.
    /// - Some(ptr) means that this node is part of a list and its successor is ptr (which might be
    ///   a null pointer, if at the end of the list).
    next: Cell<Option<*const Node>>,
}

impl Default for WaiterList {
    /// Creates an empty list.
    fn default() -> Self {
        Self { head: std::ptr::null() }
    }
}

impl WaiterList {
    /// Waits until a notification for the given address is received.
    ///
    /// The caller must provide the MutexGuard that protects the WaiterList and a function to obtain
    /// it starting from the mutex's inner type.
    ///
    /// Note: The mutex may be released and re-acquired multiple times while waiting.
    ///
    /// # Panics
    ///
    /// Panics if `panic_after_timeout` elapses without a notification, or if the mutex is found
    /// poisoned while waiting. In both cases this waiter's node is unlinked from the list before
    /// panicking, so the list never keeps a pointer into this (unwinding) stack frame.
    pub fn wait<'a, T>(
        mut guard: MutexGuard<'a, T>,
        get_waiter_list: impl Fn(&mut T) -> &mut WaiterList,
        address: u64,
        panic_after_timeout: Duration,
    ) -> MutexGuard<'a, T> {
        // The node lives in this stack frame; it must stay in place while it is linked.
        let node = std::pin::pin!(Node {
            waited_address: address,
            condvar: Condvar::new(),
            next: Cell::new(None),
        });

        // Insert it at the head of the list.
        let old_head = std::mem::replace(&mut get_waiter_list(&mut guard).head, &*node);
        node.next.set(Some(old_head));

        // When the address is notified, the node will be removed from the list and its condvar
        // notified. The condition is evaluated with the mutex held.
        let wait_result = node
            .condvar
            .wait_timeout_while(guard, panic_after_timeout, |_| node.next.get().is_some());
        let poisoned = wait_result.is_err();
        // Even when poisoned, the guard is returned to us, so the list can still be repaired.
        let (mut guard, timeout_result) = wait_result.unwrap_or_else(|err| err.into_inner());

        // Nobody removed our node (timeout or poisoning): unlink it ourselves, because the panics
        // below unwind this frame and would otherwise leave a dangling pointer in the list.
        if node.next.get().is_some() {
            let target: &Node = &*node;
            let unlinked = get_waiter_list(&mut guard).unlink_first(|n| std::ptr::eq(n, target));
            debug_assert!(unlinked.is_some(), "a linked node must be reachable from the list head");
        }

        if poisoned {
            panic!("WaiterList::wait: mutex poisoned while waiting for address {:#x}", address);
        }
        if timeout_result.timed_out() {
            panic!("WaiterList::wait timed out while waiting for address {:#x}", address);
        }

        guard
    }

    /// Notifies the first node that is waiting on the given address.
    ///
    /// Waiters are pushed at the head of the list, so this wakes the most recent waiter for
    /// `address`. Does nothing if no thread is waiting on `address`.
    pub fn notify_one(&mut self, address: u64) {
        if let Some(node) = self.unlink_first(|node| node.waited_address == address) {
            // The waiter cannot return from `wait` (and drop `node`) until it re-acquires the
            // mutex protecting this list, which the holder of `&mut self` is presumed to own.
            node.condvar.notify_one();
        }
    }

    /// Unlinks the first node (in list order) for which `matches` returns true.
    ///
    /// The unlinked node's `next` is reset to `None`, which is how its waiting thread observes
    /// the removal. Returns the unlinked node, or `None` if no node matched.
    fn unlink_first(&mut self, mut matches: impl FnMut(&Node) -> bool) -> Option<&Node> {
        let mut prev: Option<&Node> = None;
        let mut it: *const Node = self.head;

        // SAFETY: If `it` is not null, the object it points to is alive, because it's part of this
        // list: its owner is blocked in `wait` and cannot leave it before re-acquiring the mutex
        // that protects this list.
        while let Some(node) = unsafe { it.as_ref() } {
            let next = node.next.get().expect("node must be in the list");
            if matches(node) {
                match prev {
                    Some(prev) => prev.next.set(Some(next)),
                    None => self.head = next,
                }
                node.next.set(None);
                return Some(node);
            }
            prev = Some(node);
            it = next;
        }

        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    #[test]
    fn test_notify_waited_address() {
        let waiter_list = Arc::new(Mutex::new(WaiterList::default()));
        let guard = waiter_list.lock().unwrap();

        let notifier_thread = {
            let waiter_list = waiter_list.clone();
            std::thread::spawn(move || {
                let mut guard = waiter_list.lock().unwrap();
                guard.notify_one(0x1234);
            })
        };

        // Wait for an address that is notified by the above thread as soon as the `wait` function
        // internally releases the lock.
        let guard =
            WaiterList::wait(guard, |waiter_list| waiter_list, 0x1234, Duration::from_secs(10));
        drop(guard);

        // Surface any panic that happened on the notifier thread.
        notifier_thread.join().unwrap();
    }

    #[test]
    fn test_notify_other_address_keeps_waiter() {
        let waiter_list = Arc::new(Mutex::new(WaiterList::default()));
        let guard = waiter_list.lock().unwrap();

        let notifier_thread = {
            let waiter_list = waiter_list.clone();
            std::thread::spawn(move || {
                let mut guard = waiter_list.lock().unwrap();

                // A notification for a different address must leave the waiter linked.
                guard.notify_one(0x5678);
                assert!(!guard.head.is_null());

                // The matching notification unlinks it.
                guard.notify_one(0x1234);
                assert!(guard.head.is_null());
            })
        };

        let guard =
            WaiterList::wait(guard, |waiter_list| waiter_list, 0x1234, Duration::from_secs(10));
        drop(guard);

        notifier_thread.join().unwrap();
    }

    #[test]
    fn test_notify_empty_list() {
        let waiter_list = Mutex::new(WaiterList::default());
        let mut guard = waiter_list.lock().unwrap();
        guard.notify_one(0x1234);
    }

    #[test]
    #[should_panic(expected = "timed out while waiting for address 0x1234")]
    fn test_wait_timeout() {
        let waiter_list = Mutex::new(WaiterList::default());
        let guard = waiter_list.lock().unwrap();

        // Wait for an address that is never notified. This will panic due to timeout.
        let guard =
            WaiterList::wait(guard, |waiter_list| waiter_list, 0x1234, Duration::from_millis(1));
        drop(guard);
    }
}