1use std::cmp;
2use std::fmt;
3use std::mem;
4use std::mem::MaybeUninit;
5use std::sync::atomic::{AtomicUsize, Ordering};
67use crate::errors::InvalidThreadAccess;
/// Returns a fresh, process-wide unique identifier for a thread.
///
/// Identifiers are handed out from a global counter starting at 0. The
/// counter never wraps: if it would overflow, this panics instead of handing
/// out an identifier a second time, because a reused ID would let an unrelated
/// thread pass the `Fragile` thread check.
fn next_thread_id() -> usize {
    // A plain `static` is enough: atomics provide interior mutability, so no
    // `static mut` (and no `unsafe`) is required.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    COUNTER
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| id.checked_add(1))
        .expect("thread id counter overflowed")
}
1314pub(crate) fn get_thread_id() -> usize {
15thread_local!(static THREAD_ID: usize = next_thread_id());
16 THREAD_ID.with(|&x| x)
17}
/// A `Fragile<T>` wraps a non-`Send` `T` so that it can be safely sent to
/// other threads.
///
/// Once the value has been wrapped, the wrapper can be sent to other threads,
/// but accessing the value on those threads will fail.
///
/// If the value needs destruction and the fragile wrapper is on another
/// thread, the destructor will panic. Alternatively you can use `Sticky<T>`,
/// which is not going to panic but might temporarily leak the value.
pub struct Fragile<T> {
    // Identifier of the thread that created the wrapper (from `get_thread_id`).
    thread_id: usize,
    // The boxed value; initialized from construction until it is moved out.
    value: MaybeUninit<Box<T>>,
}
3132impl<T> Fragile<T> {
33/// Creates a new `Fragile` wrapping a `value`.
34 ///
35 /// The value that is moved into the `Fragile` can be non `Send` and
36 /// will be anchored to the thread that created the object. If the
37 /// fragile wrapper type ends up being send from thread to thread
38 /// only the original thread can interact with the value.
39pub fn new(value: T) -> Self {
40 Fragile {
41 value: MaybeUninit::new(Box::new(value)),
42 thread_id: get_thread_id(),
43 }
44 }
4546/// Returns `true` if the access is valid.
47 ///
48 /// This will be `false` if the value was sent to another thread.
49pub fn is_valid(&self) -> bool {
50 get_thread_id() == self.thread_id
51 }
5253#[inline(always)]
54fn assert_thread(&self) {
55if !self.is_valid() {
56panic!("trying to access wrapped value in fragile container from incorrect thread.");
57 }
58 }
5960/// Consumes the `Fragile`, returning the wrapped value.
61 ///
62 /// # Panics
63 ///
64 /// Panics if called from a different thread than the one where the
65 /// original value was created.
66pub fn into_inner(mut self) -> T {
67self.assert_thread();
68unsafe {
69let rv = mem::replace(&mut self.value, MaybeUninit::uninit());
70 mem::forget(self);
71*rv.assume_init()
72 }
73 }
7475/// Consumes the `Fragile`, returning the wrapped value if successful.
76 ///
77 /// The wrapped value is returned if this is called from the same thread
78 /// as the one where the original value was created, otherwise the
79 /// `Fragile` is returned as `Err(self)`.
80pub fn try_into_inner(self) -> Result<T, Self> {
81if get_thread_id() == self.thread_id {
82Ok(self.into_inner())
83 } else {
84Err(self)
85 }
86 }
8788/// Immutably borrows the wrapped value.
89 ///
90 /// # Panics
91 ///
92 /// Panics if the calling thread is not the one that wrapped the value.
93 /// For a non-panicking variant, use [`try_get`](#method.try_get`).
94pub fn get(&self) -> &T {
95self.assert_thread();
96unsafe { &*self.value.as_ptr() }
97 }
9899/// Mutably borrows the wrapped value.
100 ///
101 /// # Panics
102 ///
103 /// Panics if the calling thread is not the one that wrapped the value.
104 /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
105pub fn get_mut(&mut self) -> &mut T {
106self.assert_thread();
107unsafe { &mut *self.value.as_mut_ptr() }
108 }
109110/// Tries to immutably borrow the wrapped value.
111 ///
112 /// Returns `None` if the calling thread is not the one that wrapped the value.
113pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
114if get_thread_id() == self.thread_id {
115unsafe { Ok(&*self.value.as_ptr()) }
116 } else {
117Err(InvalidThreadAccess)
118 }
119 }
120121/// Tries to mutably borrow the wrapped value.
122 ///
123 /// Returns `None` if the calling thread is not the one that wrapped the value.
124pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
125if get_thread_id() == self.thread_id {
126unsafe { Ok(&mut *self.value.as_mut_ptr()) }
127 } else {
128Err(InvalidThreadAccess)
129 }
130 }
131}
132133impl<T> Drop for Fragile<T> {
134fn drop(&mut self) {
135if mem::needs_drop::<T>() {
136if get_thread_id() == self.thread_id {
137unsafe {
138let rv = mem::replace(&mut self.value, MaybeUninit::uninit());
139 rv.assume_init();
140 }
141 } else {
142panic!("destructor of fragile object ran on wrong thread");
143 }
144 }
145 }
146}
147148impl<T> From<T> for Fragile<T> {
149#[inline]
150fn from(t: T) -> Fragile<T> {
151 Fragile::new(t)
152 }
153}
154155impl<T: Clone> Clone for Fragile<T> {
156#[inline]
157fn clone(&self) -> Fragile<T> {
158 Fragile::new(self.get().clone())
159 }
160}
161162impl<T: Default> Default for Fragile<T> {
163#[inline]
164fn default() -> Fragile<T> {
165 Fragile::new(T::default())
166 }
167}
168169impl<T: PartialEq> PartialEq for Fragile<T> {
170#[inline]
171fn eq(&self, other: &Fragile<T>) -> bool {
172*self.get() == *other.get()
173 }
174}
175176impl<T: Eq> Eq for Fragile<T> {}
177178impl<T: PartialOrd> PartialOrd for Fragile<T> {
179#[inline]
180fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
181self.get().partial_cmp(&*other.get())
182 }
183184#[inline]
185fn lt(&self, other: &Fragile<T>) -> bool {
186*self.get() < *other.get()
187 }
188189#[inline]
190fn le(&self, other: &Fragile<T>) -> bool {
191*self.get() <= *other.get()
192 }
193194#[inline]
195fn gt(&self, other: &Fragile<T>) -> bool {
196*self.get() > *other.get()
197 }
198199#[inline]
200fn ge(&self, other: &Fragile<T>) -> bool {
201*self.get() >= *other.get()
202 }
203}
204205impl<T: Ord> Ord for Fragile<T> {
206#[inline]
207fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
208self.get().cmp(&*other.get())
209 }
210}
211212impl<T: fmt::Display> fmt::Display for Fragile<T> {
213fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
214 fmt::Display::fmt(self.get(), f)
215 }
216}
217218impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
219fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
220match self.try_get() {
221Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
222Err(..) => {
223struct InvalidPlaceholder;
224impl fmt::Debug for InvalidPlaceholder {
225fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226 f.write_str("<invalid thread>")
227 }
228 }
229230 f.debug_struct("Fragile")
231 .field("value", &InvalidPlaceholder)
232 .finish()
233 }
234 }
235 }
236}
237238// this type is sync because access can only ever happy from the same thread
239// that created it originally. All other threads will be able to safely
240// call some basic operations on the reference and they will fail.
241unsafe impl<T> Sync for Fragile<T> {}
242243// The entire point of this type is to be Send
244unsafe impl<T> Send for Fragile<T> {}
#[test]
fn test_basic() {
    use std::thread;

    let fragile = Fragile::new(true);
    assert_eq!(fragile.to_string(), "true");
    assert_eq!(fragile.get(), &true);
    assert!(fragile.try_get().is_ok());

    // On a foreign thread the value must be inaccessible.
    let worker = thread::spawn(move || {
        assert!(fragile.try_get().is_err());
    });
    worker.join().unwrap();
}
#[test]
fn test_mut() {
    let mut fragile = Fragile::new(true);
    {
        let slot = fragile.get_mut();
        *slot = false;
    }
    assert_eq!(fragile.to_string(), "false");
    assert_eq!(fragile.get(), &false);
}
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let fragile = Fragile::new(true);
    // `get` panics on the foreign thread; `join().unwrap()` re-raises it here.
    let worker = thread::spawn(move || {
        fragile.get();
    });
    worker.join().unwrap();
}
#[test]
fn test_noop_drop_elsewhere() {
    use std::thread;

    let fragile = Fragile::new(true);
    let worker = thread::spawn(move || {
        // Using the wrapper forces the closure to take ownership of it; since
        // `bool` has no destructor, dropping it here must not panic.
        let _ = fragile.try_get();
    });
    worker.join().unwrap();
}
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let fragile = Fragile::new(X(Arc::clone(&was_called)));
    let outcome = thread::spawn(move || {
        let _ = fragile.try_get();
    })
    .join();

    // Dropping on the foreign thread panics, and `X::drop` never runs.
    assert!(outcome.is_err());
    assert!(!was_called.load(Ordering::SeqCst));
}
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::thread;

    let fragile = Fragile::new(Rc::new(true));
    let (sender, receiver) = channel();

    let worker = thread::spawn(move || {
        assert!(fragile.try_get().is_err());
        // Ship the wrapper back to the thread that owns the `Rc`.
        sender.send(fragile).unwrap();
    });

    let returned = receiver.recv().unwrap();
    assert!(**returned.get());

    worker.join().unwrap();
}