fragile/
sticky.rs

1use std::cell::UnsafeCell;
2use std::cmp;
3use std::collections::HashMap;
4use std::fmt;
5use std::marker::PhantomData;
6use std::mem;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use crate::errors::InvalidThreadAccess;
10
/// Returns a process-wide unique id for a newly created `Sticky`.
///
/// Ids come from a single global atomic counter, so values created on
/// different threads never share an id.  `fetch_add` wraps on overflow;
/// NOTE(review): after `usize::MAX` calls a wrapped id could collide with a
/// still-live registry entry — only plausible on targets with a small `usize`.
fn next_item_id() -> usize {
    // A plain `static` is sufficient: atomics provide interior mutability,
    // so the former `static mut` + `unsafe` access was unnecessary (and
    // taking a reference to a `static mut` is flagged by `static_mut_refs`).
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    COUNTER.fetch_add(1, Ordering::SeqCst)
}
15
/// Maps a `Sticky`'s item id to its type-erased heap pointer plus the closure
/// that knows how to free that pointer as the original `Box<T>`.
type RegistryMap = HashMap<usize, (UnsafeCell<*mut ()>, Box<dyn Fn(&UnsafeCell<*mut ()>)>)>;

/// Per-thread owner of every value wrapped by a `Sticky` created on that thread.
struct Registry(RegistryMap);

impl Drop for Registry {
    /// Runs the stored destructor for every value still registered when the
    /// registry itself is dropped.
    fn drop(&mut self) {
        for (cell, destroy) in self.0.values() {
            destroy(cell);
        }
    }
}

thread_local! {
    /// This thread's registry; accessors reach the map through the `UnsafeCell`.
    static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(HashMap::new()));
}
29
/// A `Sticky<T>` keeps a value `T` stored in the thread that created it.
///
/// This type works similarly to `Fragile<T>` and exposes the same
/// interface.  The difference is that whereas `Fragile<T>` has its
/// destructor called in the thread where the value was sent, a
/// `Sticky<T>` that is moved to another thread will have the internal
/// destructor called when the originating thread tears down.
///
/// The wrapped value itself never leaves the originating thread: it lives in
/// that thread's thread-local registry, and the `Sticky` only carries the key
/// of its entry.
///
/// As this uses TLS internally the general rules about the platform limitations
/// of destructors for TLS apply.
pub struct Sticky<T> {
    // Key of this value's entry in the creating thread's registry.
    item_id: usize,
    // Ties the wrapper to `T` without storing one; the raw pointer also opts
    // the type out of the automatic `Send`/`Sync` impls.
    _marker: PhantomData<*mut T>,
}
44
45impl<T> Drop for Sticky<T> {
46    fn drop(&mut self) {
47        if mem::needs_drop::<T>() {
48            unsafe {
49                if self.is_valid() {
50                    self.unsafe_take_value();
51                }
52            }
53        }
54    }
55}
56
57impl<T> Sticky<T> {
58    /// Creates a new `Sticky` wrapping a `value`.
59    ///
60    /// The value that is moved into the `Sticky` can be non `Send` and
61    /// will be anchored to the thread that created the object.  If the
62    /// sticky wrapper type ends up being send from thread to thread
63    /// only the original thread can interact with the value.
64    pub fn new(value: T) -> Self {
65        let item_id = next_item_id();
66        REGISTRY.with(|registry| unsafe {
67            (*registry.get()).0.insert(
68                item_id,
69                (
70                    UnsafeCell::new(Box::into_raw(Box::new(value)) as *mut _),
71                    Box::new(|cell| {
72                        let b: Box<T> = Box::from_raw(*(cell.get() as *mut *mut T));
73                        mem::drop(b);
74                    }),
75                ),
76            );
77        });
78        Sticky {
79            item_id,
80            _marker: PhantomData,
81        }
82    }
83
84    #[inline(always)]
85    fn with_value<F: FnOnce(&UnsafeCell<Box<T>>) -> R, R>(&self, f: F) -> R {
86        REGISTRY.with(|registry| unsafe {
87            let reg = &(*(*registry).get()).0;
88            if let Some(item) = reg.get(&self.item_id) {
89                f(&*(&item.0 as *const UnsafeCell<*mut ()> as *const UnsafeCell<Box<T>>))
90            } else {
91                panic!("trying to access wrapped value in sticky container from incorrect thread.");
92            }
93        })
94    }
95
96    /// Returns `true` if the access is valid.
97    ///
98    /// This will be `false` if the value was sent to another thread.
99    #[inline(always)]
100    pub fn is_valid(&self) -> bool {
101        // We use `try-with` here to avoid crashing if the TLS is already tearing down.
102        unsafe {
103            REGISTRY
104                .try_with(|registry| (*registry.get()).0.contains_key(&self.item_id))
105                .unwrap_or(false)
106        }
107    }
108
109    #[inline(always)]
110    fn assert_thread(&self) {
111        if !self.is_valid() {
112            panic!("trying to access wrapped value in sticky container from incorrect thread.");
113        }
114    }
115
116    /// Consumes the `Sticky`, returning the wrapped value.
117    ///
118    /// # Panics
119    ///
120    /// Panics if called from a different thread than the one where the
121    /// original value was created.
122    pub fn into_inner(mut self) -> T {
123        self.assert_thread();
124        unsafe {
125            let rv = self.unsafe_take_value();
126            mem::forget(self);
127            rv
128        }
129    }
130
131    unsafe fn unsafe_take_value(&mut self) -> T {
132        let ptr = REGISTRY
133            .with(|registry| (*registry.get()).0.remove(&self.item_id))
134            .unwrap()
135            .0
136            .into_inner();
137        let rv = Box::from_raw(ptr as *mut T);
138        *rv
139    }
140
141    /// Consumes the `Sticky`, returning the wrapped value if successful.
142    ///
143    /// The wrapped value is returned if this is called from the same thread
144    /// as the one where the original value was created, otherwise the
145    /// `Sticky` is returned as `Err(self)`.
146    pub fn try_into_inner(self) -> Result<T, Self> {
147        if self.is_valid() {
148            Ok(self.into_inner())
149        } else {
150            Err(self)
151        }
152    }
153
154    /// Immutably borrows the wrapped value.
155    ///
156    /// # Panics
157    ///
158    /// Panics if the calling thread is not the one that wrapped the value.
159    /// For a non-panicking variant, use [`try_get`](#method.try_get`).
160    pub fn get(&self) -> &T {
161        self.with_value(|value| unsafe { &*value.get() })
162    }
163
164    /// Mutably borrows the wrapped value.
165    ///
166    /// # Panics
167    ///
168    /// Panics if the calling thread is not the one that wrapped the value.
169    /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
170    pub fn get_mut(&mut self) -> &mut T {
171        self.with_value(|value| unsafe { &mut *value.get() })
172    }
173
174    /// Tries to immutably borrow the wrapped value.
175    ///
176    /// Returns `None` if the calling thread is not the one that wrapped the value.
177    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
178        if self.is_valid() {
179            unsafe { Ok(self.with_value(|value| &*value.get())) }
180        } else {
181            Err(InvalidThreadAccess)
182        }
183    }
184
185    /// Tries to mutably borrow the wrapped value.
186    ///
187    /// Returns `None` if the calling thread is not the one that wrapped the value.
188    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
189        if self.is_valid() {
190            unsafe { Ok(self.with_value(|value| &mut *value.get())) }
191        } else {
192            Err(InvalidThreadAccess)
193        }
194    }
195}
196
197impl<T> From<T> for Sticky<T> {
198    #[inline]
199    fn from(t: T) -> Sticky<T> {
200        Sticky::new(t)
201    }
202}
203
204impl<T: Clone> Clone for Sticky<T> {
205    #[inline]
206    fn clone(&self) -> Sticky<T> {
207        Sticky::new(self.get().clone())
208    }
209}
210
211impl<T: Default> Default for Sticky<T> {
212    #[inline]
213    fn default() -> Sticky<T> {
214        Sticky::new(T::default())
215    }
216}
217
218impl<T: PartialEq> PartialEq for Sticky<T> {
219    #[inline]
220    fn eq(&self, other: &Sticky<T>) -> bool {
221        *self.get() == *other.get()
222    }
223}
224
225impl<T: Eq> Eq for Sticky<T> {}
226
227impl<T: PartialOrd> PartialOrd for Sticky<T> {
228    #[inline]
229    fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
230        self.get().partial_cmp(&*other.get())
231    }
232
233    #[inline]
234    fn lt(&self, other: &Sticky<T>) -> bool {
235        *self.get() < *other.get()
236    }
237
238    #[inline]
239    fn le(&self, other: &Sticky<T>) -> bool {
240        *self.get() <= *other.get()
241    }
242
243    #[inline]
244    fn gt(&self, other: &Sticky<T>) -> bool {
245        *self.get() > *other.get()
246    }
247
248    #[inline]
249    fn ge(&self, other: &Sticky<T>) -> bool {
250        *self.get() >= *other.get()
251    }
252}
253
254impl<T: Ord> Ord for Sticky<T> {
255    #[inline]
256    fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
257        self.get().cmp(&*other.get())
258    }
259}
260
261impl<T: fmt::Display> fmt::Display for Sticky<T> {
262    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
263        fmt::Display::fmt(self.get(), f)
264    }
265}
266
267impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
268    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
269        match self.try_get() {
270            Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
271            Err(..) => {
272                struct InvalidPlaceholder;
273                impl fmt::Debug for InvalidPlaceholder {
274                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
275                        f.write_str("<invalid thread>")
276                    }
277                }
278
279                f.debug_struct("Sticky")
280                    .field("value", &InvalidPlaceholder)
281                    .finish()
282            }
283        }
284    }
285}
286
287// similar as for fragile ths type is sync because it only accesses TLS data
288// which is thread local.  There is nothing that needs to be synchronized.
289unsafe impl<T> Sync for Sticky<T> {}
290
291// The entire point of this type is to be Send
292unsafe impl<T> Send for Sticky<T> {}
293
/// A value is readable on its creating thread and rejected on any other.
#[test]
fn test_basic() {
    let sticky = Sticky::new(true);
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(), &true);
    assert!(sticky.try_get().is_ok());

    let worker = std::thread::spawn(move || {
        assert!(sticky.try_get().is_err());
    });
    worker.join().unwrap();
}
307
/// Writes through `get_mut` are visible to later reads on the same thread.
#[test]
fn test_mut() {
    let mut sticky = Sticky::new(true);
    {
        let slot = sticky.get_mut();
        *slot = false;
    }
    assert_eq!(sticky.to_string(), "false");
    assert_eq!(sticky.get(), &false);
}
315
/// `get` from a foreign thread panics; the panic surfaces through `join`.
#[test]
#[should_panic]
fn test_access_other_thread() {
    let sticky = Sticky::new(true);
    let handle = std::thread::spawn(move || {
        sticky.get();
    });
    handle.join().unwrap();
}
327
/// Dropping a `Sticky` on its creating thread runs the value's destructor
/// immediately.
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Records in the shared flag that its destructor ran.
    struct X(Arc<AtomicBool>);
    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = Sticky::new(X(Arc::clone(&was_called)));
    mem::drop(val);
    assert!(was_called.load(Ordering::SeqCst));
}
343
/// Dropping a `Sticky` on a foreign thread must not run the value's
/// destructor; the value is only destroyed once its creating thread exits.
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let was_called = Arc::new(AtomicBool::new(false));

    {
        let was_called = was_called.clone();
        thread::spawn(move || {
            // Records in the shared flag that its destructor ran.
            struct X(Arc<AtomicBool>);
            impl Drop for X {
                fn drop(&mut self) {
                    self.0.store(true, Ordering::SeqCst);
                }
            }

            let val = Sticky::new(X(was_called.clone()));
            // Move the wrapper to a second thread and let it drop there; that
            // thread does not own the value, so nothing must be destroyed.
            assert!(thread::spawn(move || {
                val.try_get().ok();
            })
            .join()
            .is_ok());

            assert!(!was_called.load(Ordering::SeqCst));
        })
        .join()
        .unwrap();
    }

    // The creating thread has exited, so its thread-local registry has run
    // the value's destructor.
    assert!(was_called.load(Ordering::SeqCst));
}
378
/// A non-`Send` payload (`Rc`) can still cross threads inside a `Sticky`;
/// the foreign thread just cannot access it.
#[test]
fn test_rc_sending() {
    let sticky = Sticky::new(std::rc::Rc::new(true));
    let handle = std::thread::spawn(move || {
        assert!(sticky.try_get().is_err());
    });
    handle.join().unwrap();
}