fragile/
fragile.rs

1use std::cmp;
2use std::fmt;
3use std::mem;
4use std::mem::MaybeUninit;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
7use crate::errors::InvalidThreadAccess;
8
/// Hands out a fresh, process-wide thread identifier.
///
/// Each call returns the next value of a monotonically increasing counter,
/// so identifiers are never handed out twice within a process (barring
/// `usize` wraparound, which `fetch_add` performs silently).
fn next_thread_id() -> usize {
    // A plain `static` is enough: atomics already provide interior
    // mutability, so neither `static mut` nor `unsafe` is needed. Taking a
    // reference to a `static mut` (which `fetch_add(&self, ..)` does) is
    // flagged by the `static_mut_refs` lint, deny-by-default in edition 2024.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    COUNTER.fetch_add(1, Ordering::SeqCst)
}
13
14pub(crate) fn get_thread_id() -> usize {
15    thread_local!(static THREAD_ID: usize = next_thread_id());
16    THREAD_ID.with(|&x| x)
17}
18
/// A `Fragile<T>` wraps a non sendable `T` so that it can be safely sent to
/// other threads.
///
/// Once the value has been wrapped it can be sent to other threads but access
/// to the value on those threads will fail.
///
/// If the value needs destruction and the fragile wrapper is on another thread
/// the destructor will panic.  Alternatively you can use `Sticky<T>` which is
/// not going to panic but might temporarily leak the value.
pub struct Fragile<T> {
    // The wrapped value, boxed. `MaybeUninit` never drops its contents on its
    // own, so the box is moved out explicitly (by `into_inner` or `Drop`).
    value: MaybeUninit<Box<T>>,
    // Identifier (from `get_thread_id`) of the thread that created the wrapper;
    // only that thread may touch `value`.
    thread_id: usize,
}
31
32impl<T> Fragile<T> {
33    /// Creates a new `Fragile` wrapping a `value`.
34    ///
35    /// The value that is moved into the `Fragile` can be non `Send` and
36    /// will be anchored to the thread that created the object.  If the
37    /// fragile wrapper type ends up being send from thread to thread
38    /// only the original thread can interact with the value.
39    pub fn new(value: T) -> Self {
40        Fragile {
41            value: MaybeUninit::new(Box::new(value)),
42            thread_id: get_thread_id(),
43        }
44    }
45
46    /// Returns `true` if the access is valid.
47    ///
48    /// This will be `false` if the value was sent to another thread.
49    pub fn is_valid(&self) -> bool {
50        get_thread_id() == self.thread_id
51    }
52
53    #[inline(always)]
54    fn assert_thread(&self) {
55        if !self.is_valid() {
56            panic!("trying to access wrapped value in fragile container from incorrect thread.");
57        }
58    }
59
60    /// Consumes the `Fragile`, returning the wrapped value.
61    ///
62    /// # Panics
63    ///
64    /// Panics if called from a different thread than the one where the
65    /// original value was created.
66    pub fn into_inner(mut self) -> T {
67        self.assert_thread();
68        unsafe {
69            let rv = mem::replace(&mut self.value, MaybeUninit::uninit());
70            mem::forget(self);
71            *rv.assume_init()
72        }
73    }
74
75    /// Consumes the `Fragile`, returning the wrapped value if successful.
76    ///
77    /// The wrapped value is returned if this is called from the same thread
78    /// as the one where the original value was created, otherwise the
79    /// `Fragile` is returned as `Err(self)`.
80    pub fn try_into_inner(self) -> Result<T, Self> {
81        if get_thread_id() == self.thread_id {
82            Ok(self.into_inner())
83        } else {
84            Err(self)
85        }
86    }
87
88    /// Immutably borrows the wrapped value.
89    ///
90    /// # Panics
91    ///
92    /// Panics if the calling thread is not the one that wrapped the value.
93    /// For a non-panicking variant, use [`try_get`](#method.try_get`).
94    pub fn get(&self) -> &T {
95        self.assert_thread();
96        unsafe { &*self.value.as_ptr() }
97    }
98
99    /// Mutably borrows the wrapped value.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the calling thread is not the one that wrapped the value.
104    /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
105    pub fn get_mut(&mut self) -> &mut T {
106        self.assert_thread();
107        unsafe { &mut *self.value.as_mut_ptr() }
108    }
109
110    /// Tries to immutably borrow the wrapped value.
111    ///
112    /// Returns `None` if the calling thread is not the one that wrapped the value.
113    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
114        if get_thread_id() == self.thread_id {
115            unsafe { Ok(&*self.value.as_ptr()) }
116        } else {
117            Err(InvalidThreadAccess)
118        }
119    }
120
121    /// Tries to mutably borrow the wrapped value.
122    ///
123    /// Returns `None` if the calling thread is not the one that wrapped the value.
124    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
125        if get_thread_id() == self.thread_id {
126            unsafe { Ok(&mut *self.value.as_mut_ptr()) }
127        } else {
128            Err(InvalidThreadAccess)
129        }
130    }
131}
132
133impl<T> Drop for Fragile<T> {
134    fn drop(&mut self) {
135        if mem::needs_drop::<T>() {
136            if get_thread_id() == self.thread_id {
137                unsafe {
138                    let rv = mem::replace(&mut self.value, MaybeUninit::uninit());
139                    rv.assume_init();
140                }
141            } else {
142                panic!("destructor of fragile object ran on wrong thread");
143            }
144        }
145    }
146}
147
148impl<T> From<T> for Fragile<T> {
149    #[inline]
150    fn from(t: T) -> Fragile<T> {
151        Fragile::new(t)
152    }
153}
154
155impl<T: Clone> Clone for Fragile<T> {
156    #[inline]
157    fn clone(&self) -> Fragile<T> {
158        Fragile::new(self.get().clone())
159    }
160}
161
162impl<T: Default> Default for Fragile<T> {
163    #[inline]
164    fn default() -> Fragile<T> {
165        Fragile::new(T::default())
166    }
167}
168
169impl<T: PartialEq> PartialEq for Fragile<T> {
170    #[inline]
171    fn eq(&self, other: &Fragile<T>) -> bool {
172        *self.get() == *other.get()
173    }
174}
175
176impl<T: Eq> Eq for Fragile<T> {}
177
178impl<T: PartialOrd> PartialOrd for Fragile<T> {
179    #[inline]
180    fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
181        self.get().partial_cmp(&*other.get())
182    }
183
184    #[inline]
185    fn lt(&self, other: &Fragile<T>) -> bool {
186        *self.get() < *other.get()
187    }
188
189    #[inline]
190    fn le(&self, other: &Fragile<T>) -> bool {
191        *self.get() <= *other.get()
192    }
193
194    #[inline]
195    fn gt(&self, other: &Fragile<T>) -> bool {
196        *self.get() > *other.get()
197    }
198
199    #[inline]
200    fn ge(&self, other: &Fragile<T>) -> bool {
201        *self.get() >= *other.get()
202    }
203}
204
205impl<T: Ord> Ord for Fragile<T> {
206    #[inline]
207    fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
208        self.get().cmp(&*other.get())
209    }
210}
211
212impl<T: fmt::Display> fmt::Display for Fragile<T> {
213    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
214        fmt::Display::fmt(self.get(), f)
215    }
216}
217
218impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
219    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
220        match self.try_get() {
221            Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
222            Err(..) => {
223                struct InvalidPlaceholder;
224                impl fmt::Debug for InvalidPlaceholder {
225                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226                        f.write_str("<invalid thread>")
227                    }
228                }
229
230                f.debug_struct("Fragile")
231                    .field("value", &InvalidPlaceholder)
232                    .finish()
233            }
234        }
235    }
236}
237
// This type is `Sync` because access to the wrapped value can only ever
// happen from the thread that created it originally.  All other threads can
// safely call some basic operations on a shared reference (such as
// `is_valid` or `try_get`), and those operations fail (by returning an error
// or panicking) instead of touching the value.
unsafe impl<T> Sync for Fragile<T> {}

// The entire point of this type is to be `Send`: the wrapper itself may move
// between threads, while the accessors compare the calling thread's id with
// the creating thread's id before handing out the value.
unsafe impl<T> Send for Fragile<T> {}
245
#[test]
fn test_basic() {
    use std::thread;

    let val = Fragile::new(true);
    // On the creating thread every accessor works.
    assert_eq!(val.to_string(), "true");
    assert_eq!(val.get(), &true);
    assert!(val.try_get().is_ok());

    // After the move, the non-panicking accessor reports the wrong thread.
    let worker = thread::spawn(move || {
        assert!(val.try_get().is_err());
    });
    worker.join().unwrap();
}
259
#[test]
fn test_mut() {
    let mut val = Fragile::new(true);
    {
        let slot = val.get_mut();
        *slot = false;
    }
    assert_eq!(val.to_string(), "false");
    assert_eq!(val.get(), &false);
}
267
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let val = Fragile::new(true);
    let worker = thread::spawn(move || {
        // `get` asserts the owning thread, so this panics inside the worker.
        val.get();
    });
    // Re-raise the worker's panic on the test thread.
    worker.join().unwrap();
}
279
#[test]
fn test_noop_drop_elsewhere() {
    use std::thread;

    let val = Fragile::new(true);
    let worker = thread::spawn(move || {
        // Using `val` moves it into the closure, so it is dropped on this
        // foreign thread; a `bool` has no destructor, so this must not panic.
        let _ = val.try_get();
    });
    worker.join().unwrap();
}
291
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Flags whether the wrapped value's destructor ever ran.
    struct X(Arc<AtomicBool>);
    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = Fragile::new(X(Arc::clone(&was_called)));
    let outcome = thread::spawn(move || {
        let _ = val.try_get();
    })
    .join();

    // Dropping on the foreign thread must panic...
    assert!(outcome.is_err());
    // ...without ever running the wrapped value's destructor.
    assert!(!was_called.load(Ordering::SeqCst));
}
312
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::thread;

    let val = Fragile::new(Rc::new(true));
    let (tx, rx) = channel();

    let worker = thread::spawn(move || {
        // The worker may not look inside, but it can hand the wrapper back.
        assert!(val.try_get().is_err());
        tx.send(val).unwrap();
    });

    let returned = rx.recv().unwrap();
    assert!(**returned.get());

    worker.join().unwrap();
}