oneshot_sync/
lib.rs
1use std::ops::DerefMut as _;
8
9#[cfg(loom)]
10pub(crate) use loom::sync::{Arc, Condvar, Mutex};
11#[cfg(not(loom))]
12pub(crate) use std::sync::{Arc, Condvar, Mutex};
13
/// Lifecycle of the single value carried by a one-shot channel.
///
/// A channel starts out `Pending`. The sender moves it to `Done` (a value was sent) or to
/// `Dropped` (the sender went away without sending). The receiver later turns `Done` into
/// `Finished` when it takes the value out.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CompletionState<T> {
    /// Nothing has happened yet; the receiver has to keep waiting.
    Pending,
    /// The sender was dropped before it sent anything.
    Dropped,
    /// The sender delivered this value and the receiver has not taken it yet.
    Done(T),
    /// The receiver has already moved the value out of `Done`.
    Finished,
}

impl<T> CompletionState<T> {
    /// Stores `next_completion_state` in `self`, requiring that the displaced state was `Pending`.
    ///
    /// # Panics
    ///
    /// Panics, naming the variant it found, when the displaced state was not `Pending`. The swap
    /// has already happened at that point, so `self` holds `next_completion_state` when the panic
    /// fires.
    fn assert_pending_and_replace(&mut self, next_completion_state: CompletionState<T>) {
        let displaced = std::mem::replace(self, next_completion_state);
        match displaced {
            CompletionState::Pending => {}
            CompletionState::Dropped => panic!("CompletionState should be Pending, not Dropped"),
            CompletionState::Done(_) => panic!("CompletionState should be Pending, not Done"),
            CompletionState::Finished => panic!("CompletionState should be Pending, not Finished"),
        }
    }
}
39
40#[derive(Debug)]
41struct CompletionSignal<T> {
42 done: Mutex<CompletionState<T>>,
43 condvar: Condvar,
44}
45
46pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
48 let inner = Arc::new(CompletionSignal {
49 done: Mutex::new(CompletionState::Pending),
50 condvar: Condvar::new(),
51 });
52 (Sender { inner: Some(inner.clone()) }, Receiver { inner })
53}
54
55#[derive(Debug)]
57pub struct Sender<T> {
58 inner: Option<Arc<CompletionSignal<T>>>,
59}
60
61impl<T> Sender<T> {
62 pub fn send(mut self, done_value: T) {
65 let inner = self.inner.take().expect("should be present");
68 std::mem::forget(self);
69
70 let CompletionSignal { done, condvar } = inner.as_ref();
71 let mut done = done.lock().unwrap();
72 done.assert_pending_and_replace(CompletionState::Done(done_value));
73 condvar.notify_one();
74 }
75}
76
77impl<T> Drop for Sender<T> {
78 fn drop(&mut self) {
79 let Self { inner } = self;
80 let inner = inner.take().expect("should be present");
83 let CompletionSignal { done, condvar } = inner.as_ref();
84 let mut done = done.lock().unwrap();
85 done.assert_pending_and_replace(CompletionState::Dropped);
86 condvar.notify_one();
87 }
88}
89
90#[derive(Debug)]
92pub struct Receiver<T> {
93 inner: Arc<CompletionSignal<T>>,
94}
95
96#[derive(Debug, PartialEq, Eq)]
98pub struct SenderDroppedError;
99
100impl<T> Receiver<T> {
101 pub fn receive(self) -> Result<T, SenderDroppedError> {
103 let Self { inner } = self;
104 let CompletionSignal { done, condvar } = inner.as_ref();
105
106 let mut done = done.lock().expect("should not be poisoned");
108 loop {
109 done = match done.deref_mut() {
110 CompletionState::Pending => condvar.wait(done).expect("should not be poisoned"),
111 CompletionState::Dropped => return Err(SenderDroppedError),
112 CompletionState::Finished => panic!("should not have finished during receive"),
113 CompletionState::Done(_) => break,
114 }
115 }
116
117 let done = std::mem::replace(done.deref_mut(), CompletionState::Finished);
118 match done {
119 CompletionState::Done(done) => Ok(done),
120 _ => unreachable!(),
121 }
122 }
123}
124
// Loom model-checking tests for the channel.
//
// These drive `loom::model`, which can only explore interleavings of loom's own
// synchronization primitives. The crate re-exports loom's `Arc`/`Mutex`/`Condvar` only under
// `cfg(loom)`. Under a plain `cargo test`, the channel would use std's primitives instead.
// Loom cannot see those, and a receiver parked in std's `Condvar::wait` can stall the model.
// Run with:
//
//     RUSTFLAGS="--cfg loom" cargo test --release
#[cfg(all(test, loom))]
mod tests {
    use super::*;
    use loom::sync::atomic::AtomicU8;
    use loom::thread;
    use std::sync::atomic::Ordering;

    #[test]
    fn waits_for_completion() {
        loom::model(|| {
            let (sender, receiver) = channel();
            // Relaxed on purpose: the atomic adds no synchronization of its own. The sender's
            // swap can only be ordered before the receiver's by the channel itself.
            let data_to_race_on = Arc::new(AtomicU8::new(0));

            let sender_thread = thread::spawn({
                let data_to_race_on = data_to_race_on.clone();
                move || {
                    let previous_value = data_to_race_on.swap(1, Ordering::Relaxed);
                    assert_eq!(previous_value, 0);
                    sender.send(());
                }
            });

            let receiver_thread = thread::spawn(move || {
                receiver.receive().expect("sender should not be dropped");
                let previous_value = data_to_race_on.swap(2, Ordering::Relaxed);
                assert_eq!(previous_value, 1);
            });

            sender_thread.join().expect("should succeed");
            receiver_thread.join().expect("should succeed");
        });
    }

    #[test]
    fn observes_drop() {
        loom::model(|| {
            let (sender, receiver) = channel::<()>();
            // Same Relaxed ordering trick as above, but completion happens by dropping the sender.
            let data_to_race_on = Arc::new(AtomicU8::new(0));

            let sender_thread = thread::spawn({
                let data_to_race_on = data_to_race_on.clone();
                move || {
                    let previous_value = data_to_race_on.swap(1, Ordering::Relaxed);
                    assert_eq!(previous_value, 0);
                    drop(sender);
                }
            });

            let receiver_thread = thread::spawn(move || {
                let SenderDroppedError =
                    receiver.receive().expect_err("sender should be dropped");
                let previous_value = data_to_race_on.swap(2, Ordering::Relaxed);
                assert_eq!(previous_value, 1);
            });

            sender_thread.join().expect("should succeed");
            receiver_thread.join().expect("should succeed");
        });
    }
}