// fragile/sticky.rs

1#![allow(clippy::unit_arg)]
2
3use std::cmp;
4use std::fmt;
5use std::marker::PhantomData;
6use std::mem;
7use std::num::NonZeroUsize;
8
9use crate::errors::InvalidThreadAccess;
10use crate::registry;
11use crate::thread_id;
12use crate::StackToken;
13
14/// A [`Sticky<T>`] keeps a value T stored in a thread.
15///
16/// This type works similar in nature to [`Fragile`](crate::Fragile) and exposes a
17/// similar interface.  The difference is that whereas [`Fragile`](crate::Fragile) has
18/// its destructor called in the thread where the value was sent, a
19/// [`Sticky`] that is moved to another thread will have the internal
20/// destructor called when the originating thread tears down.
21///
22/// Because [`Sticky`] allows values to be kept alive for longer than the
23/// [`Sticky`] itself, it requires all its contents to be `'static` for
24/// soundness.  More importantly it also requires the use of [`StackToken`]s.
25/// For information about how to use stack tokens and why they are needed,
26/// refer to [`stack_token!`](crate::stack_token).
27///
28/// As this uses TLS internally the general rules about the platform limitations
29/// of destructors for TLS apply.
30pub struct Sticky<T: 'static> {
31    item_id: registry::ItemId,
32    thread_id: NonZeroUsize,
33    _marker: PhantomData<*mut T>,
34}
35
36impl<T> Drop for Sticky<T> {
37    fn drop(&mut self) {
38        // if the type needs dropping we can only do so on the right thread.
39        // worst case we leak the value until the thread dies when drop will be
40        // called by the registry.
41        if mem::needs_drop::<T>() {
42            unsafe {
43                if self.is_valid() {
44                    self.unsafe_take_value();
45                }
46            }
47        }
48    }
49}
50
51impl<T> Sticky<T> {
52    /// Creates a new [`Sticky`] wrapping a `value`.
53    ///
54    /// The value that is moved into the [`Sticky`] can be non `Send` and
55    /// will be anchored to the thread that created the object.  If the
56    /// sticky wrapper type ends up being send from thread to thread
57    /// only the original thread can interact with the value.
58    pub fn new(value: T) -> Self {
59        let entry = registry::Entry {
60            ptr: Box::into_raw(Box::new(value)).cast(),
61            drop: |ptr| {
62                let ptr = ptr.cast::<T>();
63                // SAFETY: This callback will only be called once, with the
64                // above pointer.
65                drop(unsafe { Box::from_raw(ptr) });
66            },
67        };
68
69        let thread_id = thread_id::get();
70        let item_id = registry::insert(entry);
71
72        Sticky {
73            item_id,
74            thread_id,
75            _marker: PhantomData,
76        }
77    }
78
79    #[inline(always)]
80    fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R {
81        self.assert_thread();
82
83        registry::with(self.item_id, |entry| f(entry.ptr.cast::<T>()))
84    }
85
86    /// Returns `true` if the access is valid.
87    ///
88    /// This will be `false` if the value was sent to another thread.
89    #[inline(always)]
90    pub fn is_valid(&self) -> bool {
91        thread_id::get() == self.thread_id
92    }
93
94    #[inline(always)]
95    fn assert_thread(&self) {
96        if !self.is_valid() {
97            panic!("trying to access wrapped value in sticky container from incorrect thread.");
98        }
99    }
100
101    /// Consumes the `Sticky`, returning the wrapped value.
102    ///
103    /// # Panics
104    ///
105    /// Panics if called from a different thread than the one where the
106    /// original value was created.
107    pub fn into_inner(mut self) -> T {
108        self.assert_thread();
109        unsafe {
110            let rv = self.unsafe_take_value();
111            mem::forget(self);
112            rv
113        }
114    }
115
116    unsafe fn unsafe_take_value(&mut self) -> T {
117        let ptr = registry::try_remove(self.item_id).unwrap().ptr.cast::<T>();
118        *Box::from_raw(ptr)
119    }
120
121    /// Consumes the `Sticky`, returning the wrapped value if successful.
122    ///
123    /// The wrapped value is returned if this is called from the same thread
124    /// as the one where the original value was created, otherwise the
125    /// `Sticky` is returned as `Err(self)`.
126    pub fn try_into_inner(self) -> Result<T, Self> {
127        if self.is_valid() {
128            Ok(self.into_inner())
129        } else {
130            Err(self)
131        }
132    }
133
134    /// Immutably borrows the wrapped value.
135    ///
136    /// # Panics
137    ///
138    /// Panics if the calling thread is not the one that wrapped the value.
139    /// For a non-panicking variant, use [`try_get`](#method.try_get`).
140    pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
141        self.with_value(|value| unsafe { &*value })
142    }
143
144    /// Mutably borrows the wrapped value.
145    ///
146    /// # Panics
147    ///
148    /// Panics if the calling thread is not the one that wrapped the value.
149    /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
150    pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
151        self.with_value(|value| unsafe { &mut *value })
152    }
153
154    /// Tries to immutably borrow the wrapped value.
155    ///
156    /// Returns `None` if the calling thread is not the one that wrapped the value.
157    pub fn try_get<'stack>(
158        &'stack self,
159        _proof: &'stack StackToken,
160    ) -> Result<&'stack T, InvalidThreadAccess> {
161        if self.is_valid() {
162            Ok(self.with_value(|value| unsafe { &*value }))
163        } else {
164            Err(InvalidThreadAccess)
165        }
166    }
167
168    /// Tries to mutably borrow the wrapped value.
169    ///
170    /// Returns `None` if the calling thread is not the one that wrapped the value.
171    pub fn try_get_mut<'stack>(
172        &'stack mut self,
173        _proof: &'stack StackToken,
174    ) -> Result<&'stack mut T, InvalidThreadAccess> {
175        if self.is_valid() {
176            Ok(self.with_value(|value| unsafe { &mut *value }))
177        } else {
178            Err(InvalidThreadAccess)
179        }
180    }
181}
182
183impl<T> From<T> for Sticky<T> {
184    #[inline]
185    fn from(t: T) -> Sticky<T> {
186        Sticky::new(t)
187    }
188}
189
190impl<T: Clone> Clone for Sticky<T> {
191    #[inline]
192    fn clone(&self) -> Sticky<T> {
193        crate::stack_token!(tok);
194        Sticky::new(self.get(tok).clone())
195    }
196}
197
198impl<T: Default> Default for Sticky<T> {
199    #[inline]
200    fn default() -> Sticky<T> {
201        Sticky::new(T::default())
202    }
203}
204
205impl<T: PartialEq> PartialEq for Sticky<T> {
206    #[inline]
207    fn eq(&self, other: &Sticky<T>) -> bool {
208        crate::stack_token!(tok);
209        *self.get(tok) == *other.get(tok)
210    }
211}
212
213impl<T: Eq> Eq for Sticky<T> {}
214
215impl<T: PartialOrd> PartialOrd for Sticky<T> {
216    #[inline]
217    fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
218        crate::stack_token!(tok);
219        self.get(tok).partial_cmp(other.get(tok))
220    }
221
222    #[inline]
223    fn lt(&self, other: &Sticky<T>) -> bool {
224        crate::stack_token!(tok);
225        *self.get(tok) < *other.get(tok)
226    }
227
228    #[inline]
229    fn le(&self, other: &Sticky<T>) -> bool {
230        crate::stack_token!(tok);
231        *self.get(tok) <= *other.get(tok)
232    }
233
234    #[inline]
235    fn gt(&self, other: &Sticky<T>) -> bool {
236        crate::stack_token!(tok);
237        *self.get(tok) > *other.get(tok)
238    }
239
240    #[inline]
241    fn ge(&self, other: &Sticky<T>) -> bool {
242        crate::stack_token!(tok);
243        *self.get(tok) >= *other.get(tok)
244    }
245}
246
247impl<T: Ord> Ord for Sticky<T> {
248    #[inline]
249    fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
250        crate::stack_token!(tok);
251        self.get(tok).cmp(other.get(tok))
252    }
253}
254
255impl<T: fmt::Display> fmt::Display for Sticky<T> {
256    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
257        crate::stack_token!(tok);
258        fmt::Display::fmt(self.get(tok), f)
259    }
260}
261
262impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
263    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
264        crate::stack_token!(tok);
265        match self.try_get(tok) {
266            Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
267            Err(..) => {
268                struct InvalidPlaceholder;
269                impl fmt::Debug for InvalidPlaceholder {
270                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
271                        f.write_str("<invalid thread>")
272                    }
273                }
274
275                f.debug_struct("Sticky")
276                    .field("value", &InvalidPlaceholder)
277                    .finish()
278            }
279        }
280    }
281}
282
283// similar as for fragile the type is sync because it only accesses TLS data
284// which is thread local.  There is nothing that needs to be synchronized.
285unsafe impl<T> Sync for Sticky<T> {}
286
287// The entire point of this type is to be Send
288unsafe impl<T> Send for Sticky<T> {}
289
#[test]
fn test_basic() {
    use std::thread;

    let val = Sticky::new(true);
    crate::stack_token!(tok);
    assert_eq!(val.to_string(), "true");
    assert_eq!(val.get(tok), &true);
    assert!(val.try_get(tok).is_ok());

    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        // created on the test thread, so unreachable from this one
        assert!(val.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
305
#[test]
fn test_mut() {
    let mut val = Sticky::new(true);
    crate::stack_token!(tok);

    let slot = val.get_mut(tok);
    *slot = false;

    assert_eq!(val.to_string(), "false");
    assert_eq!(val.get(tok), &false);
}
314
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let val = Sticky::new(true);
    let worker = thread::spawn(move || {
        crate::stack_token!(tok);
        // `get` panics here: this thread does not own the value
        let _ = val.get(tok);
    });
    worker.join().unwrap();
}
327
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Flips the shared flag when dropped.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = Sticky::new(X(Arc::clone(&was_called)));
    mem::drop(val);
    assert!(was_called.load(Ordering::SeqCst));
}
343
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Flips the shared flag when dropped.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&was_called);

    thread::spawn(move || {
        let val = Sticky::new(X(Arc::clone(&flag)));

        // move the sticky to another thread and let it go out of scope
        // there; this must not run the destructor
        let mover = thread::spawn(move || {
            crate::stack_token!(tok);
            val.try_get(tok).ok();
        });
        assert!(mover.join().is_ok());

        assert!(!flag.load(Ordering::SeqCst));
    })
    .join()
    .unwrap();

    // the owning thread has exited, so its registry dropped the value
    assert!(was_called.load(Ordering::SeqCst));
}
379
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is not `Send`, yet the wrapper can still be moved across threads
    let val = Sticky::new(Rc::new(true));
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(val.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
392
#[test]
fn test_two_stickies() {
    // Has a (no-op) destructor, so `mem::needs_drop::<Wat>()` is true.
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {
            // intentionally empty
        }
    }

    let first = Sticky::new(Wat);
    let second = Sticky::new(Wat);

    // dropping two live stickies from the same thread must work cleanly
    drop(first);
    drop(second);
}
411
#[test]
fn test_thread_spawn() {
    use crate::{stack_token, Sticky};
    use std::{mem::ManuallyDrop, thread};

    // A sticky that was created by (and belongs to) another thread ...
    let dummy_sticky = thread::spawn(|| Sticky::new(())).join().unwrap();
    // ... and one owned by this thread, kept in `ManuallyDrop` so this test
    // never drops it.
    let sticky_string = ManuallyDrop::new(Sticky::new(String::from("Hello World")));
    stack_token!(t);

    let hello: &str = sticky_string.get(t);
    assert_eq!(hello, "Hello World");

    // dropping the foreign sticky here must leave our value untouched
    drop(dummy_sticky);
    assert_eq!(hello, "Hello World");
}