oneshot_sync/
lib.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Provides a synchronous thread-safe oneshot channel.
6
7use std::ops::DerefMut as _;
8
9#[cfg(loom)]
10pub(crate) use loom::sync::{Arc, Condvar, Mutex};
11#[cfg(not(loom))]
12pub(crate) use std::sync::{Arc, Condvar, Mutex};
13
/// Progress of a oneshot, kept behind the shared `CompletionSignal` mutex.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CompletionState<T> {
    /// The sender has neither sent a value nor been dropped yet.
    Pending,
    /// The sender went away without ever sending.
    Dropped,
    /// The sender delivered a value that the receiver has not taken yet.
    Done(T),
    /// The receiver has moved the delivered value out.
    Finished,
}
26
27impl<T> CompletionState<T> {
28    fn assert_pending_and_replace(&mut self, next_completion_state: CompletionState<T>) {
29        // We'd prefer to use `assert_matches` here, but `T` doesn't necessarily
30        // impl Debug.
31        match std::mem::replace(self, next_completion_state) {
32            CompletionState::Pending => (),
33            CompletionState::Done(_) => panic!("CompletionState should be Pending, not Done"),
34            CompletionState::Dropped => panic!("CompletionState should be Pending, not Dropped"),
35            CompletionState::Finished => panic!("CompletionState should be Pending, not Finished"),
36        }
37    }
38}
39
40#[derive(Debug)]
41struct CompletionSignal<T> {
42    done: Mutex<CompletionState<T>>,
43    condvar: Condvar,
44}
45
46/// Creates a pair of sender and corresponding receiver.
47pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
48    let inner = Arc::new(CompletionSignal {
49        done: Mutex::new(CompletionState::Pending),
50        condvar: Condvar::new(),
51    });
52    (Sender { inner: Some(inner.clone()) }, Receiver { inner })
53}
54
55/// The sender side of a pair of sender and receiver.
56#[derive(Debug)]
57pub struct Sender<T> {
58    inner: Option<Arc<CompletionSignal<T>>>,
59}
60
61impl<T> Sender<T> {
62    /// Sends the `done_value` to the peer, waking it up if it is blocked on
63    /// receiving.
64    pub fn send(mut self, done_value: T) {
65        // `inner` is present unless we have already sent or been dropped, which
66        // is impossible.
67        let inner = self.inner.take().expect("should be present");
68        std::mem::forget(self);
69
70        let CompletionSignal { done, condvar } = inner.as_ref();
71        let mut done = done.lock().unwrap();
72        done.assert_pending_and_replace(CompletionState::Done(done_value));
73        condvar.notify_one();
74    }
75}
76
77impl<T> Drop for Sender<T> {
78    fn drop(&mut self) {
79        let Self { inner } = self;
80        // `inner` is present unless we have already sent or been dropped, which
81        // is impossible.
82        let inner = inner.take().expect("should be present");
83        let CompletionSignal { done, condvar } = inner.as_ref();
84        let mut done = done.lock().unwrap();
85        done.assert_pending_and_replace(CompletionState::Dropped);
86        condvar.notify_one();
87    }
88}
89
90/// Used to block on a completion signal being observed.
91#[derive(Debug)]
92pub struct Receiver<T> {
93    inner: Arc<CompletionSignal<T>>,
94}
95
/// Error returned by [`Receiver::receive`] when the [`Sender`] was dropped
/// without ever sending a value.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct SenderDroppedError;

impl std::fmt::Display for SenderDroppedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("sender was dropped without sending a value")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and other
// error-erasing types.
impl std::error::Error for SenderDroppedError {}
99
100impl<T> Receiver<T> {
101    /// Blocks until the sender sends or is dropped.
102    pub fn receive(self) -> Result<T, SenderDroppedError> {
103        let Self { inner } = self;
104        let CompletionSignal { done, condvar } = inner.as_ref();
105
106        // Loop until we're no longer pending.
107        let mut done = done.lock().expect("should not be poisoned");
108        loop {
109            done = match done.deref_mut() {
110                CompletionState::Pending => condvar.wait(done).expect("should not be poisoned"),
111                CompletionState::Dropped => return Err(SenderDroppedError),
112                CompletionState::Finished => panic!("should not have finished during receive"),
113                CompletionState::Done(_) => break,
114            }
115        }
116
117        let done = std::mem::replace(done.deref_mut(), CompletionState::Finished);
118        match done {
119            CompletionState::Done(done) => Ok(done),
120            _ => unreachable!(),
121        }
122    }
123}
124
/// Loom model-checking tests for the oneshot channel.
///
/// NOTE(review): these tests call `loom::model` / `loom::thread`
/// unconditionally, while the crate only switches to loom's `Mutex`/`Condvar`
/// under `cfg(loom)`. Presumably they are meant to be built with
/// `--cfg loom`; confirm in the build configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use loom::sync::atomic::AtomicU8;
    use loom::thread;
    use std::sync::atomic::Ordering;

    #[test]
    fn waits_for_completion() {
        // Spawn a sender and a receiver thread that both want to write to the
        // same atomic, where the receiver thread should wait until the completion
        // signal before writing to it.

        // `loom::model` runs the closure under every thread interleaving loom
        // explores.
        loom::model(|| {
            let (sender, receiver) = channel();
            let data_to_race_on = Arc::new(AtomicU8::new(0));

            let sender_thread = thread::spawn({
                let data_to_race_on = data_to_race_on.clone();
                move || {
                    // `Relaxed` is deliberate: the atomic adds no
                    // synchronization of its own, so any ordering observed
                    // by the receiver must come from the channel.
                    let previous_value = data_to_race_on.swap(1, Ordering::Relaxed);
                    assert_eq!(previous_value, 0);
                    sender.send(());
                }
            });

            let receiver_thread = thread::spawn({
                move || {
                    receiver.receive().expect("sender should not be dropped");
                    // The sender wrote `1` before sending, so it must be
                    // visible once `receive` has returned.
                    let previous_value = data_to_race_on.swap(2, Ordering::Relaxed);
                    assert_eq!(previous_value, 1);
                }
            });

            sender_thread.join().expect("should succeed");
            receiver_thread.join().expect("should succeed");
        });
    }

    #[test]
    fn observes_drop() {
        // Same shape as `waits_for_completion`, except that the sender is
        // dropped instead of sending. The receiver must still wait for the
        // drop and then see the sender's earlier write.
        loom::model(|| {
            let (sender, receiver) = channel::<()>();
            let data_to_race_on = Arc::new(AtomicU8::new(0));

            let sender_thread = thread::spawn({
                let data_to_race_on = data_to_race_on.clone();
                move || {
                    let previous_value = data_to_race_on.swap(1, Ordering::Relaxed);
                    assert_eq!(previous_value, 0);
                    drop(sender);
                }
            });

            let receiver_thread = thread::spawn({
                move || {
                    // The irrefutable `SenderDroppedError` pattern pins the
                    // expected error type.
                    let SenderDroppedError =
                        receiver.receive().expect_err("sender should be dropped");
                    let previous_value = data_to_race_on.swap(2, Ordering::Relaxed);
                    assert_eq!(previous_value, 1);
                }
            });

            sender_thread.join().expect("should succeed");
            receiver_thread.join().expect("should succeed");
        });
    }
}