1#![allow(clippy::unit_arg)]
2
3use std::cmp;
4use std::fmt;
5use std::marker::PhantomData;
6use std::mem;
7use std::num::NonZeroUsize;
8
9use crate::errors::InvalidThreadAccess;
10use crate::registry;
11use crate::thread_id;
12use crate::StackToken;
13
14pub struct Sticky<T: 'static> {
31 item_id: registry::ItemId,
32 thread_id: NonZeroUsize,
33 _marker: PhantomData<*mut T>,
34}
35
36impl<T> Drop for Sticky<T> {
37 fn drop(&mut self) {
38 if mem::needs_drop::<T>() {
42 unsafe {
43 if self.is_valid() {
44 self.unsafe_take_value();
45 }
46 }
47 }
48 }
49}
50
51impl<T> Sticky<T> {
52 pub fn new(value: T) -> Self {
59 let entry = registry::Entry {
60 ptr: Box::into_raw(Box::new(value)).cast(),
61 drop: |ptr| {
62 let ptr = ptr.cast::<T>();
63 drop(unsafe { Box::from_raw(ptr) });
66 },
67 };
68
69 let thread_id = thread_id::get();
70 let item_id = registry::insert(entry);
71
72 Sticky {
73 item_id,
74 thread_id,
75 _marker: PhantomData,
76 }
77 }
78
79 #[inline(always)]
80 fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R {
81 self.assert_thread();
82
83 registry::with(self.item_id, |entry| f(entry.ptr.cast::<T>()))
84 }
85
86 #[inline(always)]
90 pub fn is_valid(&self) -> bool {
91 thread_id::get() == self.thread_id
92 }
93
94 #[inline(always)]
95 fn assert_thread(&self) {
96 if !self.is_valid() {
97 panic!("trying to access wrapped value in sticky container from incorrect thread.");
98 }
99 }
100
101 pub fn into_inner(mut self) -> T {
108 self.assert_thread();
109 unsafe {
110 let rv = self.unsafe_take_value();
111 mem::forget(self);
112 rv
113 }
114 }
115
116 unsafe fn unsafe_take_value(&mut self) -> T {
117 let ptr = registry::try_remove(self.item_id).unwrap().ptr.cast::<T>();
118 *Box::from_raw(ptr)
119 }
120
121 pub fn try_into_inner(self) -> Result<T, Self> {
127 if self.is_valid() {
128 Ok(self.into_inner())
129 } else {
130 Err(self)
131 }
132 }
133
134 pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
141 self.with_value(|value| unsafe { &*value })
142 }
143
144 pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
151 self.with_value(|value| unsafe { &mut *value })
152 }
153
154 pub fn try_get<'stack>(
158 &'stack self,
159 _proof: &'stack StackToken,
160 ) -> Result<&'stack T, InvalidThreadAccess> {
161 if self.is_valid() {
162 Ok(self.with_value(|value| unsafe { &*value }))
163 } else {
164 Err(InvalidThreadAccess)
165 }
166 }
167
168 pub fn try_get_mut<'stack>(
172 &'stack mut self,
173 _proof: &'stack StackToken,
174 ) -> Result<&'stack mut T, InvalidThreadAccess> {
175 if self.is_valid() {
176 Ok(self.with_value(|value| unsafe { &mut *value }))
177 } else {
178 Err(InvalidThreadAccess)
179 }
180 }
181}
182
183impl<T> From<T> for Sticky<T> {
184 #[inline]
185 fn from(t: T) -> Sticky<T> {
186 Sticky::new(t)
187 }
188}
189
190impl<T: Clone> Clone for Sticky<T> {
191 #[inline]
192 fn clone(&self) -> Sticky<T> {
193 crate::stack_token!(tok);
194 Sticky::new(self.get(tok).clone())
195 }
196}
197
198impl<T: Default> Default for Sticky<T> {
199 #[inline]
200 fn default() -> Sticky<T> {
201 Sticky::new(T::default())
202 }
203}
204
205impl<T: PartialEq> PartialEq for Sticky<T> {
206 #[inline]
207 fn eq(&self, other: &Sticky<T>) -> bool {
208 crate::stack_token!(tok);
209 *self.get(tok) == *other.get(tok)
210 }
211}
212
213impl<T: Eq> Eq for Sticky<T> {}
214
215impl<T: PartialOrd> PartialOrd for Sticky<T> {
216 #[inline]
217 fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
218 crate::stack_token!(tok);
219 self.get(tok).partial_cmp(other.get(tok))
220 }
221
222 #[inline]
223 fn lt(&self, other: &Sticky<T>) -> bool {
224 crate::stack_token!(tok);
225 *self.get(tok) < *other.get(tok)
226 }
227
228 #[inline]
229 fn le(&self, other: &Sticky<T>) -> bool {
230 crate::stack_token!(tok);
231 *self.get(tok) <= *other.get(tok)
232 }
233
234 #[inline]
235 fn gt(&self, other: &Sticky<T>) -> bool {
236 crate::stack_token!(tok);
237 *self.get(tok) > *other.get(tok)
238 }
239
240 #[inline]
241 fn ge(&self, other: &Sticky<T>) -> bool {
242 crate::stack_token!(tok);
243 *self.get(tok) >= *other.get(tok)
244 }
245}
246
247impl<T: Ord> Ord for Sticky<T> {
248 #[inline]
249 fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
250 crate::stack_token!(tok);
251 self.get(tok).cmp(other.get(tok))
252 }
253}
254
255impl<T: fmt::Display> fmt::Display for Sticky<T> {
256 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
257 crate::stack_token!(tok);
258 fmt::Display::fmt(self.get(tok), f)
259 }
260}
261
262impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
263 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
264 crate::stack_token!(tok);
265 match self.try_get(tok) {
266 Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
267 Err(..) => {
268 struct InvalidPlaceholder;
269 impl fmt::Debug for InvalidPlaceholder {
270 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
271 f.write_str("<invalid thread>")
272 }
273 }
274
275 f.debug_struct("Sticky")
276 .field("value", &InvalidPlaceholder)
277 .finish()
278 }
279 }
280 }
281}
282
283unsafe impl<T> Sync for Sticky<T> {}
286
287unsafe impl<T> Send for Sticky<T> {}
289
#[test]
fn test_basic() {
    use std::thread;

    // On the creating thread the value is fully accessible...
    let sticky = Sticky::new(true);
    crate::stack_token!(tok);
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(tok), &true);
    assert!(sticky.try_get(tok).is_ok());

    // ...but after moving to another thread, the non-panicking accessor reports an error.
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(sticky.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
305
#[test]
fn test_mut() {
    let mut sticky = Sticky::new(true);
    crate::stack_token!(tok);

    // Writes through `get_mut` must be visible to later reads.
    let slot = sticky.get_mut(tok);
    *slot = false;

    assert_eq!(sticky.to_string(), "false");
    assert_eq!(sticky.get(tok), &false);
}
314
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let sticky = Sticky::new(true);
    // `get` panics on a foreign thread; `join` turns that into an `Err`, and the
    // `unwrap` below re-raises it on the test thread.
    let outcome = thread::spawn(move || {
        crate::stack_token!(tok);
        sticky.get(tok);
    })
    .join();
    outcome.unwrap();
}
327
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Sets the shared flag when dropped.
    struct FlagOnDrop(Arc<AtomicBool>);

    impl Drop for FlagOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let sticky = Sticky::new(FlagOnDrop(Arc::clone(&dropped)));

    // Dropping on the owning thread must run the wrapped value's destructor.
    mem::drop(sticky);
    assert!(dropped.load(Ordering::SeqCst));
}
343
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let dropped = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&dropped);

    let owner = thread::spawn(move || {
        /// Sets the shared flag when dropped.
        struct FlagOnDrop(Arc<AtomicBool>);

        impl Drop for FlagOnDrop {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        let sticky = Sticky::new(FlagOnDrop(flag.clone()));

        // Move the sticky to a foreign thread and let it be dropped there: that
        // must neither panic nor run the destructor.
        let foreign = thread::spawn(move || {
            crate::stack_token!(tok);
            sticky.try_get(tok).ok();
        });
        assert!(foreign.join().is_ok());

        assert!(!flag.load(Ordering::SeqCst));
    });
    owner.join().unwrap();

    // Once the owning thread has exited, the value has been dropped after all.
    assert!(dropped.load(Ordering::SeqCst));
}
379
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is not `Send`, yet the wrapper may still cross threads: the value just
    // becomes inaccessible on the other side.
    let sticky = Sticky::new(Rc::new(true));
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(sticky.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
392
#[test]
fn test_two_stickies() {
    /// Has an (empty) destructor, so `mem::needs_drop::<Wat>()` is `true`.
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {}
    }

    let first = Sticky::new(Wat);
    let second = Sticky::new(Wat);

    // Two independent entries must be removable one after the other.
    drop(first);
    drop(second);
}
411
/// Checks that dropping a `Sticky` owned by another (already finished) thread leaves
/// the current thread's values intact: `hello` borrows from `sticky_string` and must
/// still read "Hello World" after `drop(dummy_sticky)`.
#[test]
fn test_thread_spawn() {
    use crate::{stack_token, Sticky};
    use std::{mem::ManuallyDrop, thread};

    // Created by a spawned thread that has exited by the time `join` returns.
    let dummy_sticky = thread::spawn(|| Sticky::new(())).join().unwrap();
    // `ManuallyDrop` keeps this sticky from ever being dropped within the test.
    let sticky_string = ManuallyDrop::new(Sticky::new(String::from("Hello World")));
    stack_token!(t);

    let hello: &str = sticky_string.get(t);

    assert_eq!(hello, "Hello World");
    // Dropped on a thread that does not own it; must not disturb `sticky_string`.
    drop(dummy_sticky);
    assert_eq!(hello, "Hello World");
}