fragile/
fragile.rs

1use std::cmp;
2use std::fmt;
3use std::mem;
4use std::num::NonZeroUsize;
5
6use crate::errors::InvalidThreadAccess;
7use crate::thread_id;
8use std::mem::ManuallyDrop;
9
/// A [`Fragile<T>`] wraps a non-`Send` `T` so that it can be safely sent to
/// other threads.
///
/// Once the value has been wrapped it can be sent to other threads, but
/// access to the value on those threads will fail.
///
/// If the value needs destruction and the fragile wrapper is on another thread
/// the destructor will panic (see the `Drop` impl for the exact rules).
/// Alternatively you can use [`Sticky`](crate::Sticky) which is not going to
/// panic but might temporarily leak the value.
pub struct Fragile<T> {
    // ManuallyDrop is necessary because we need to move out of here without running the
    // Drop code in functions like `into_inner`.
    value: ManuallyDrop<T>,
    // Id of the thread that created this wrapper (captured via `thread_id::get()`
    // in `new`); every access compares it against the current thread's id.
    thread_id: NonZeroUsize,
}
25
26impl<T> Fragile<T> {
27    /// Creates a new [`Fragile`] wrapping a `value`.
28    ///
29    /// The value that is moved into the [`Fragile`] can be non `Send` and
30    /// will be anchored to the thread that created the object.  If the
31    /// fragile wrapper type ends up being send from thread to thread
32    /// only the original thread can interact with the value.
33    pub fn new(value: T) -> Self {
34        Fragile {
35            value: ManuallyDrop::new(value),
36            thread_id: thread_id::get(),
37        }
38    }
39
40    /// Returns `true` if the access is valid.
41    ///
42    /// This will be `false` if the value was sent to another thread.
43    pub fn is_valid(&self) -> bool {
44        thread_id::get() == self.thread_id
45    }
46
47    #[inline(always)]
48    fn assert_thread(&self) {
49        if !self.is_valid() {
50            panic!("trying to access wrapped value in fragile container from incorrect thread.");
51        }
52    }
53
54    /// Consumes the `Fragile`, returning the wrapped value.
55    ///
56    /// # Panics
57    ///
58    /// Panics if called from a different thread than the one where the
59    /// original value was created.
60    pub fn into_inner(self) -> T {
61        self.assert_thread();
62
63        let mut this = ManuallyDrop::new(self);
64
65        // SAFETY: `this` is not accessed beyond this point, and because it's in a ManuallyDrop its
66        // destructor is not run.
67        unsafe { ManuallyDrop::take(&mut this.value) }
68    }
69
70    /// Consumes the `Fragile`, returning the wrapped value if successful.
71    ///
72    /// The wrapped value is returned if this is called from the same thread
73    /// as the one where the original value was created, otherwise the
74    /// [`Fragile`] is returned as `Err(self)`.
75    pub fn try_into_inner(self) -> Result<T, Self> {
76        if thread_id::get() == self.thread_id {
77            Ok(self.into_inner())
78        } else {
79            Err(self)
80        }
81    }
82
83    /// Immutably borrows the wrapped value.
84    ///
85    /// # Panics
86    ///
87    /// Panics if the calling thread is not the one that wrapped the value.
88    /// For a non-panicking variant, use [`try_get`](Self::try_get).
89    pub fn get(&self) -> &T {
90        self.assert_thread();
91        &self.value
92    }
93
94    /// Mutably borrows the wrapped value.
95    ///
96    /// # Panics
97    ///
98    /// Panics if the calling thread is not the one that wrapped the value.
99    /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut).
100    pub fn get_mut(&mut self) -> &mut T {
101        self.assert_thread();
102        &mut self.value
103    }
104
105    /// Tries to immutably borrow the wrapped value.
106    ///
107    /// Returns `None` if the calling thread is not the one that wrapped the value.
108    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
109        if thread_id::get() == self.thread_id {
110            Ok(&*self.value)
111        } else {
112            Err(InvalidThreadAccess)
113        }
114    }
115
116    /// Tries to mutably borrow the wrapped value.
117    ///
118    /// Returns `None` if the calling thread is not the one that wrapped the value.
119    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
120        if thread_id::get() == self.thread_id {
121            Ok(&mut *self.value)
122        } else {
123            Err(InvalidThreadAccess)
124        }
125    }
126}
127
128impl<T> Drop for Fragile<T> {
129    fn drop(&mut self) {
130        if mem::needs_drop::<T>() {
131            if thread_id::get() == self.thread_id {
132                // SAFETY: `ManuallyDrop::drop` cannot be called after this point.
133                unsafe { ManuallyDrop::drop(&mut self.value) };
134            } else {
135                panic!("destructor of fragile object ran on wrong thread");
136            }
137        }
138    }
139}
140
141impl<T> From<T> for Fragile<T> {
142    #[inline]
143    fn from(t: T) -> Fragile<T> {
144        Fragile::new(t)
145    }
146}
147
148impl<T: Clone> Clone for Fragile<T> {
149    #[inline]
150    fn clone(&self) -> Fragile<T> {
151        Fragile::new(self.get().clone())
152    }
153}
154
155impl<T: Default> Default for Fragile<T> {
156    #[inline]
157    fn default() -> Fragile<T> {
158        Fragile::new(T::default())
159    }
160}
161
162impl<T: PartialEq> PartialEq for Fragile<T> {
163    #[inline]
164    fn eq(&self, other: &Fragile<T>) -> bool {
165        *self.get() == *other.get()
166    }
167}
168
169impl<T: Eq> Eq for Fragile<T> {}
170
171impl<T: PartialOrd> PartialOrd for Fragile<T> {
172    #[inline]
173    fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
174        self.get().partial_cmp(other.get())
175    }
176
177    #[inline]
178    fn lt(&self, other: &Fragile<T>) -> bool {
179        *self.get() < *other.get()
180    }
181
182    #[inline]
183    fn le(&self, other: &Fragile<T>) -> bool {
184        *self.get() <= *other.get()
185    }
186
187    #[inline]
188    fn gt(&self, other: &Fragile<T>) -> bool {
189        *self.get() > *other.get()
190    }
191
192    #[inline]
193    fn ge(&self, other: &Fragile<T>) -> bool {
194        *self.get() >= *other.get()
195    }
196}
197
198impl<T: Ord> Ord for Fragile<T> {
199    #[inline]
200    fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
201        self.get().cmp(other.get())
202    }
203}
204
205impl<T: fmt::Display> fmt::Display for Fragile<T> {
206    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
207        fmt::Display::fmt(self.get(), f)
208    }
209}
210
211impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
212    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
213        match self.try_get() {
214            Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
215            Err(..) => {
216                struct InvalidPlaceholder;
217                impl fmt::Debug for InvalidPlaceholder {
218                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219                        f.write_str("<invalid thread>")
220                    }
221                }
222
223                f.debug_struct("Fragile")
224                    .field("value", &InvalidPlaceholder)
225                    .finish()
226            }
227        }
228    }
229}
230
// SAFETY: this type is `Sync` because access to the wrapped value can only
// ever happen from the thread that created it originally.  All other threads
// can still safely call the basic operations through a shared reference, but
// those operations fail: the `try_*` accessors return an error and the other
// accessors panic.
unsafe impl<T> Sync for Fragile<T> {}
235
// SAFETY: the entire point of this type is to be `Send`.  Moving a `Fragile`
// to another thread never exposes `T` there: every accessor checks the thread
// id first, and `Drop` refuses to run `T`'s destructor on a foreign thread.
// The clippy lint is silenced deliberately, because carrying a possibly
// non-`Send` field (`value`) across threads is exactly what this wrapper is for.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T> Send for Fragile<T> {}
239
#[test]
fn test_basic() {
    let fragile = Fragile::new(true);
    assert_eq!(fragile.to_string(), "true");
    assert_eq!(fragile.get(), &true);
    assert!(fragile.try_get().is_ok());

    // After moving to a different thread, non-panicking access must be refused.
    let worker = std::thread::spawn(move || assert!(fragile.try_get().is_err()));
    worker.join().unwrap();
}
253
#[test]
fn test_mut() {
    let mut fragile = Fragile::new(true);
    {
        let slot = fragile.get_mut();
        *slot = false;
    }
    assert_eq!(fragile.to_string(), "false");
    assert_eq!(fragile.get(), &false);
}
261
#[test]
#[should_panic]
fn test_access_other_thread() {
    let fragile = Fragile::new(true);
    // `get` panics on the foreign thread, so `join` yields `Err` and the
    // final `unwrap` propagates that failure into this test.
    let outcome = std::thread::spawn(move || {
        fragile.get();
    })
    .join();
    outcome.unwrap();
}
273
#[test]
fn test_noop_drop_elsewhere() {
    let fragile = Fragile::new(true);
    let worker = std::thread::spawn(move || {
        // Using the wrapper makes the closure take ownership of it, so it is
        // dropped on this thread; `bool` has no destructor, so that is harmless.
        let _ = fragile.try_get();
    });
    worker.join().unwrap();
}
285
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Records whether its destructor has ever run.
    struct SetOnDrop(Arc<AtomicBool>);
    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let fragile = Fragile::new(SetOnDrop(Arc::clone(&dropped)));

    let result = std::thread::spawn(move || {
        // Taking ownership here means the wrapper is dropped on this thread.
        let _ = fragile.try_get();
    })
    .join();

    // The foreign-thread drop panics and must not run the inner destructor.
    assert!(result.is_err());
    assert!(!dropped.load(Ordering::SeqCst));
}
306
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc;

    let fragile = Fragile::new(Rc::new(true));
    let (sender, receiver) = mpsc::channel();

    let worker = std::thread::spawn(move || {
        // The worker cannot look inside, but it may pass the wrapper along.
        assert!(fragile.try_get().is_err());
        sender.send(fragile).unwrap();
    });

    // Back on the creating thread, the `Rc` is accessible again.
    let returned = receiver.recv().unwrap();
    assert!(**returned.get());

    worker.join().unwrap();
}