// fragile/sticky.rs

use std::cell::UnsafeCell;
use std::cmp;
use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::errors::InvalidThreadAccess;

fn next_item_id() -> usize {
    // Monotonically increasing, process-wide unique id for each stored value.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    COUNTER.fetch_add(1, Ordering::SeqCst)
}

// Per-thread map from item id to the type-erased pointer of the stored value,
// paired with a boxed drop function that remembers the value's real type.
type RegistryMap = HashMap<usize, (UnsafeCell<*mut ()>, Box<dyn Fn(&UnsafeCell<*mut ()>)>)>;

struct Registry(RegistryMap);

impl Drop for Registry {
    fn drop(&mut self) {
        // When the owning thread tears down, drop every value that is still
        // registered by invoking its stored drop function.
        for (_, value) in self.0.iter() {
            (value.1)(&value.0);
        }
    }
}

thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(Default::default())));

/// A `Sticky<T>` keeps a value `T` stored in the thread it was created in.
///
/// The handle itself can be sent and shared freely between threads, but the
/// wrapped value is only accessible from the originating thread.  If the
/// handle is dropped on another thread, the value is cleaned up when the
/// originating thread's registry tears down.
pub struct Sticky<T> {
    item_id: usize,
    _marker: PhantomData<*mut T>,
}

impl<T> Drop for Sticky<T> {
    fn drop(&mut self) {
        if mem::needs_drop::<T>() {
            unsafe {
                // Drop the value eagerly if we are on the owning thread;
                // otherwise the thread-local registry drops it when the
                // owning thread exits.
                if self.is_valid() {
                    self.unsafe_take_value();
                }
            }
        }
    }
}

impl<T> Sticky<T> {
    /// Creates a new `Sticky` wrapping a `value`.
    ///
    /// The value can be non-`Send`; it is anchored to the thread that created
    /// it.  If the `Sticky` is sent to another thread, only the original
    /// thread can interact with the value.
    pub fn new(value: T) -> Self {
        let item_id = next_item_id();
        REGISTRY.with(|registry| unsafe {
            (*registry.get()).0.insert(
                item_id,
                (
                    // Type-erase the boxed value...
                    UnsafeCell::new(Box::into_raw(Box::new(value)) as *mut _),
                    // ...and pair it with a closure that remembers `T` so the
                    // registry can drop it at thread teardown.
                    Box::new(|cell| {
                        let b: Box<T> = Box::from_raw(*(cell.get() as *mut *mut T));
                        mem::drop(b);
                    }),
                ),
            );
        });
        Sticky {
            item_id,
            _marker: PhantomData,
        }
    }

    #[inline(always)]
    fn with_value<F: FnOnce(&UnsafeCell<Box<T>>) -> R, R>(&self, f: F) -> R {
        REGISTRY.with(|registry| unsafe {
            let reg = &(*registry.get()).0;
            if let Some(item) = reg.get(&self.item_id) {
                // Reinterpret the stored raw pointer as a `Box<T>`; a `Box`
                // is represented as a single non-null pointer, so the
                // layouts match.
                f(&*(&item.0 as *const UnsafeCell<*mut ()> as *const UnsafeCell<Box<T>>))
            } else {
                panic!("trying to access wrapped value in sticky container from incorrect thread.");
            }
        })
    }

    /// Returns `true` if the wrapped value is accessible from this thread.
    ///
    /// This is `false` when the `Sticky` has been sent to another thread.
    #[inline(always)]
    pub fn is_valid(&self) -> bool {
        unsafe {
            REGISTRY
                .try_with(|registry| (*registry.get()).0.contains_key(&self.item_id))
                .unwrap_or(false)
        }
    }

    #[inline(always)]
    fn assert_thread(&self) {
        if !self.is_valid() {
            panic!("trying to access wrapped value in sticky container from incorrect thread.");
        }
    }

    /// Consumes the `Sticky`, returning the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one the value was
    /// created on.
    pub fn into_inner(mut self) -> T {
        self.assert_thread();
        unsafe {
            let rv = self.unsafe_take_value();
            // The value is already out of the registry; skip `Drop`.
            mem::forget(self);
            rv
        }
    }

    unsafe fn unsafe_take_value(&mut self) -> T {
        // Remove the entry so the registry will not try to drop it again,
        // then reconstruct the box and move the value out of it.
        let ptr = REGISTRY
            .with(|registry| (*registry.get()).0.remove(&self.item_id))
            .unwrap()
            .0
            .into_inner();
        let rv = Box::from_raw(ptr as *mut T);
        *rv
    }

    /// Consumes the `Sticky`, returning the wrapped value if this thread
    /// created it, and `Err(self)` otherwise.
    pub fn try_into_inner(self) -> Result<T, Self> {
        if self.is_valid() {
            Ok(self.into_inner())
        } else {
            Err(self)
        }
    }

    /// Immutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread is not the one that wrapped the value.
    /// For a non-panicking variant, use `try_get`.
    pub fn get(&self) -> &T {
        self.with_value(|value| unsafe { &*value.get() })
    }

    /// Mutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread is not the one that wrapped the value.
    /// For a non-panicking variant, use `try_get_mut`.
    pub fn get_mut(&mut self) -> &mut T {
        self.with_value(|value| unsafe { &mut *value.get() })
    }

    /// Tries to immutably borrow the wrapped value, returning
    /// `Err(InvalidThreadAccess)` from any thread but the originating one.
    pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
        if self.is_valid() {
            unsafe { Ok(self.with_value(|value| &*value.get())) }
        } else {
            Err(InvalidThreadAccess)
        }
    }

    /// Tries to mutably borrow the wrapped value, returning
    /// `Err(InvalidThreadAccess)` from any thread but the originating one.
    pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
        if self.is_valid() {
            unsafe { Ok(self.with_value(|value| &mut *value.get())) }
        } else {
            Err(InvalidThreadAccess)
        }
    }
}

impl<T> From<T> for Sticky<T> {
    #[inline]
    fn from(t: T) -> Sticky<T> {
        Sticky::new(t)
    }
}

impl<T: Clone> Clone for Sticky<T> {
    #[inline]
    fn clone(&self) -> Sticky<T> {
        Sticky::new(self.get().clone())
    }
}

impl<T: Default> Default for Sticky<T> {
    #[inline]
    fn default() -> Sticky<T> {
        Sticky::new(T::default())
    }
}

impl<T: PartialEq> PartialEq for Sticky<T> {
    #[inline]
    fn eq(&self, other: &Sticky<T>) -> bool {
        *self.get() == *other.get()
    }
}

impl<T: Eq> Eq for Sticky<T> {}

impl<T: PartialOrd> PartialOrd for Sticky<T> {
    #[inline]
    fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
        self.get().partial_cmp(&*other.get())
    }

    #[inline]
    fn lt(&self, other: &Sticky<T>) -> bool {
        *self.get() < *other.get()
    }

    #[inline]
    fn le(&self, other: &Sticky<T>) -> bool {
        *self.get() <= *other.get()
    }

    #[inline]
    fn gt(&self, other: &Sticky<T>) -> bool {
        *self.get() > *other.get()
    }

    #[inline]
    fn ge(&self, other: &Sticky<T>) -> bool {
        *self.get() >= *other.get()
    }
}

impl<T: Ord> Ord for Sticky<T> {
    #[inline]
    fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
        self.get().cmp(&*other.get())
    }
}

impl<T: fmt::Display> fmt::Display for Sticky<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        fmt::Display::fmt(self.get(), f)
    }
}

impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        match self.try_get() {
            Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
            Err(..) => {
                struct InvalidPlaceholder;
                impl fmt::Debug for InvalidPlaceholder {
                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                        f.write_str("<invalid thread>")
                    }
                }

                f.debug_struct("Sticky")
                    .field("value", &InvalidPlaceholder)
                    .finish()
            }
        }
    }
}

// The wrapped value lives in the creating thread's local registry and is only
// ever dereferenced from that thread, so handing the `Sticky` handle itself
// to other threads (the whole point of this type) is safe.
unsafe impl<T> Sync for Sticky<T> {}

unsafe impl<T> Send for Sticky<T> {}
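
// Illustrative test, not from the original suite: on the owning thread,
// `into_inner` moves the wrapped value back out and consumes the `Sticky`.
#[test]
fn test_into_inner_same_thread() {
    let val = Sticky::new(vec![1, 2, 3]);
    assert_eq!(val.into_inner(), vec![1, 2, 3]);
}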

#[test]
fn test_basic() {
    use std::thread;
    let val = Sticky::new(true);
    assert_eq!(val.to_string(), "true");
    assert_eq!(val.get(), &true);
    assert!(val.try_get().is_ok());
    thread::spawn(move || {
        assert!(val.try_get().is_err());
    })
    .join()
    .unwrap();
}
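
// Illustrative test, not from the original suite: `try_into_inner` succeeds
// on the thread that created the value.
#[test]
fn test_try_into_inner_same_thread() {
    let val = Sticky::new(42i32);
    assert_eq!(val.try_into_inner().ok(), Some(42));
}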

#[test]
fn test_mut() {
    let mut val = Sticky::new(true);
    *val.get_mut() = false;
    assert_eq!(val.to_string(), "false");
    assert_eq!(val.get(), &false);
}
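
// Illustrative test, not from the original suite: the `From` and `Default`
// impls simply delegate to `Sticky::new`.
#[test]
fn test_from_and_default() {
    let from: Sticky<i32> = Sticky::from(1);
    let default: Sticky<i32> = Default::default();
    assert_eq!(*from.get(), 1);
    assert_eq!(*default.get(), 0);
}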

#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;
    let val = Sticky::new(true);
    thread::spawn(move || {
        val.get();
    })
    .join()
    .unwrap();
}
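
// Illustrative test, not from the original suite: unlike `get`, calling
// `try_into_inner` from a foreign thread hands the `Sticky` back as `Err`
// instead of panicking.
#[test]
fn test_try_into_inner_other_thread() {
    use std::thread;
    let val = Sticky::new(true);
    thread::spawn(move || {
        assert!(val.try_into_inner().is_err());
    })
    .join()
    .unwrap();
}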

#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    let was_called = Arc::new(AtomicBool::new(false));
    struct X(Arc<AtomicBool>);
    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }
    let val = Sticky::new(X(was_called.clone()));
    mem::drop(val);
    assert_eq!(was_called.load(Ordering::SeqCst), true);
}
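
// Illustrative test, not from the original suite: `Clone` deep-clones the
// wrapped value into a fresh registry slot, so the two handles are
// independent afterwards.
#[test]
fn test_clone_independent() {
    let a = Sticky::new(String::from("hi"));
    let mut b = a.clone();
    b.get_mut().push('!');
    assert_eq!(a.get(), "hi");
    assert_eq!(b.get(), "hi!");
}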

#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let was_called = Arc::new(AtomicBool::new(false));

    {
        let was_called = was_called.clone();
        thread::spawn(move || {
            struct X(Arc<AtomicBool>);
            impl Drop for X {
                fn drop(&mut self) {
                    self.0.store(true, Ordering::SeqCst);
                }
            }

            let val = Sticky::new(X(was_called.clone()));
            assert!(thread::spawn(move || {
                // Dropping `val` here is a no-op: the value is not dropped
                // until the originating thread exits.
                val.try_get().ok();
            })
            .join()
            .is_ok());

            assert_eq!(was_called.load(Ordering::SeqCst), false);
        })
        .join()
        .unwrap();
    }

    assert_eq!(was_called.load(Ordering::SeqCst), true);
}
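
// Illustrative test, not from the original suite: `is_valid` reports whether
// the current thread can reach the wrapped value.
#[test]
fn test_is_valid_across_threads() {
    use std::thread;
    let val = Sticky::new(5i32);
    assert!(val.is_valid());
    thread::spawn(move || {
        assert!(!val.is_valid());
    })
    .join()
    .unwrap();
}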

#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;
    let val = Sticky::new(Rc::new(true));
    thread::spawn(move || {
        assert!(val.try_get().is_err());
    })
    .join()
    .unwrap();
}
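
// Illustrative test, not from the original suite: the comparison impls defer
// to the wrapped values, so ordering works on the owning thread.
#[test]
fn test_ordering() {
    let a = Sticky::new(1);
    let b = Sticky::new(2);
    assert!(a < b);
    assert_eq!(a.cmp(&b), cmp::Ordering::Less);
}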