tracing_mutex/
stdsync.rs

1//! Tracing mutex wrappers for locks found in `std::sync`.
2//!
3//! This module provides wrappers for `std::sync` primitives with exactly the same API and
4//! functionality as their counterparts, with the exception that their acquisition order is tracked.
5//!
6//! Dedicated wrappers that provide the dependency tracing can be found in the [`tracing`] module.
7//! The original primitives are available from [`std::sync`], imported as [`raw`] for convenience.
8//!
//! If debug assertions are enabled, this module re-exports the primitives from [`tracing`];
//! otherwise it re-exports them from [`raw`].
11//!
12//! ```rust
13//! # use tracing_mutex::stdsync::tracing::Mutex;
14//! # use tracing_mutex::stdsync::tracing::RwLock;
15//! let mutex = Mutex::new(());
16//! mutex.lock().unwrap();
17//!
18//! let rwlock = RwLock::new(());
19//! rwlock.read().unwrap();
20//! ```
21pub use std::sync as raw;
22
23#[cfg(not(debug_assertions))]
24pub use std::sync::{
25    Condvar, Mutex, MutexGuard, Once, OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard,
26};
27
28#[cfg(debug_assertions)]
29pub use tracing::{
30    Condvar, Mutex, MutexGuard, Once, OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard,
31};
32
33/// Dependency tracing versions of [`std::sync`].
34pub mod tracing {
35    use std::fmt;
36    use std::ops::Deref;
37    use std::ops::DerefMut;
38    use std::sync;
39    use std::sync::LockResult;
40    use std::sync::OnceState;
41    use std::sync::PoisonError;
42    use std::sync::TryLockError;
43    use std::sync::TryLockResult;
44    use std::sync::WaitTimeoutResult;
45    use std::time::Duration;
46
47    use crate::BorrowedMutex;
48    use crate::LazyMutexId;
49
50    /// Wrapper for [`std::sync::Mutex`].
51    ///
52    /// Refer to the [crate-level][`crate`] documentation for the differences between this struct and
53    /// the one it wraps.
54    #[derive(Debug, Default)]
55    pub struct Mutex<T> {
56        inner: sync::Mutex<T>,
57        id: LazyMutexId,
58    }
59
60    /// Wrapper for [`std::sync::MutexGuard`].
61    ///
62    /// Refer to the [crate-level][`crate`] documentation for the differences between this struct and
63    /// the one it wraps.
64    #[derive(Debug)]
65    pub struct MutexGuard<'a, T> {
66        inner: sync::MutexGuard<'a, T>,
67        _mutex: BorrowedMutex<'a>,
68    }
69
70    fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T>
71    where
72        F: FnOnce(I) -> T,
73    {
74        match result {
75            Ok(inner) => Ok(mapper(inner)),
76            Err(poisoned) => Err(PoisonError::new(mapper(poisoned.into_inner()))),
77        }
78    }
79
80    fn map_trylockresult<T, I, F>(result: TryLockResult<I>, mapper: F) -> TryLockResult<T>
81    where
82        F: FnOnce(I) -> T,
83    {
84        match result {
85            Ok(inner) => Ok(mapper(inner)),
86            Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock),
87            Err(TryLockError::Poisoned(poisoned)) => {
88                Err(PoisonError::new(mapper(poisoned.into_inner())).into())
89            }
90        }
91    }
92
93    impl<T> Mutex<T> {
94        /// Create a new tracing mutex with the provided value.
95        pub const fn new(t: T) -> Self {
96            Self {
97                inner: sync::Mutex::new(t),
98                id: LazyMutexId::new(),
99            }
100        }
101
102        /// Wrapper for [`std::sync::Mutex::lock`].
103        ///
104        /// # Panics
105        ///
106        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
107        /// dependency cycle, this method will panic.
108        #[track_caller]
109        pub fn lock(&self) -> LockResult<MutexGuard<T>> {
110            let mutex = self.id.get_borrowed();
111            let result = self.inner.lock();
112
113            let mapper = |guard| MutexGuard {
114                _mutex: mutex,
115                inner: guard,
116            };
117
118            map_lockresult(result, mapper)
119        }
120
121        /// Wrapper for [`std::sync::Mutex::try_lock`].
122        ///
123        /// # Panics
124        ///
125        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
126        /// dependency cycle, this method will panic.
127        #[track_caller]
128        pub fn try_lock(&self) -> TryLockResult<MutexGuard<T>> {
129            let mutex = self.id.get_borrowed();
130            let result = self.inner.try_lock();
131
132            let mapper = |guard| MutexGuard {
133                _mutex: mutex,
134                inner: guard,
135            };
136
137            map_trylockresult(result, mapper)
138        }
139
140        /// Wrapper for [`std::sync::Mutex::is_poisoned`].
141        pub fn is_poisoned(&self) -> bool {
142            self.inner.is_poisoned()
143        }
144
145        /// Return a mutable reference to the underlying data.
146        ///
147        /// This method does not block as the locking is handled compile-time by the type system.
148        pub fn get_mut(&mut self) -> LockResult<&mut T> {
149            self.inner.get_mut()
150        }
151
152        /// Unwrap the mutex and return its inner value.
153        pub fn into_inner(self) -> LockResult<T> {
154            self.inner.into_inner()
155        }
156    }
157
158    impl<T> From<T> for Mutex<T> {
159        fn from(t: T) -> Self {
160            Self::new(t)
161        }
162    }
163
164    impl<'a, T> Deref for MutexGuard<'a, T> {
165        type Target = T;
166
167        fn deref(&self) -> &Self::Target {
168            &self.inner
169        }
170    }
171
172    impl<'a, T> DerefMut for MutexGuard<'a, T> {
173        fn deref_mut(&mut self) -> &mut Self::Target {
174            &mut self.inner
175        }
176    }
177
178    impl<'a, T: fmt::Display> fmt::Display for MutexGuard<'a, T> {
179        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180            self.inner.fmt(f)
181        }
182    }
183
184    /// Wrapper around [`std::sync::Condvar`].
185    ///
186    /// Allows `TracingMutexGuard` to be used with a `Condvar`. Unlike other structs in this module,
187    /// this wrapper does not add any additional dependency tracking or other overhead on top of the
188    /// primitive it wraps. All dependency tracking happens through the mutexes itself.
189    ///
190    /// # Panics
191    ///
192    /// This struct does not add any panics over the base implementation of `Condvar`, but panics due to
193    /// dependency tracking may poison associated mutexes.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// use std::sync::Arc;
199    /// use std::thread;
200    ///
201    /// use tracing_mutex::stdsync::tracing::{Condvar, Mutex};
202    ///
203    /// let pair = Arc::new((Mutex::new(false), Condvar::new()));
204    /// let pair2 = Arc::clone(&pair);
205    ///
206    /// // Spawn a thread that will unlock the condvar
207    /// thread::spawn(move || {
208    ///     let (lock, condvar) = &*pair2;
209    ///     *lock.lock().unwrap() = true;
210    ///     condvar.notify_one();
211    /// });
212    ///
213    /// // Wait until the thread unlocks the condvar
214    /// let (lock, condvar) = &*pair;
215    /// let guard = lock.lock().unwrap();
216    /// let guard = condvar.wait_while(guard, |started| !*started).unwrap();
217    ///
218    /// // Guard should read true now
219    /// assert!(*guard);
220    /// ```
221    #[derive(Debug, Default)]
222    pub struct Condvar(sync::Condvar);
223
224    impl Condvar {
225        /// Creates a new condition variable which is ready to be waited on and notified.
226        pub const fn new() -> Self {
227            Self(sync::Condvar::new())
228        }
229
230        /// Wrapper for [`std::sync::Condvar::wait`].
231        pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> LockResult<MutexGuard<'a, T>> {
232            let MutexGuard { _mutex, inner } = guard;
233
234            map_lockresult(self.0.wait(inner), |inner| MutexGuard { _mutex, inner })
235        }
236
237        /// Wrapper for [`std::sync::Condvar::wait_while`].
238        pub fn wait_while<'a, T, F>(
239            &self,
240            guard: MutexGuard<'a, T>,
241            condition: F,
242        ) -> LockResult<MutexGuard<'a, T>>
243        where
244            F: FnMut(&mut T) -> bool,
245        {
246            let MutexGuard { _mutex, inner } = guard;
247
248            map_lockresult(self.0.wait_while(inner, condition), |inner| MutexGuard {
249                _mutex,
250                inner,
251            })
252        }
253
254        /// Wrapper for [`std::sync::Condvar::wait_timeout`].
255        pub fn wait_timeout<'a, T>(
256            &self,
257            guard: MutexGuard<'a, T>,
258            dur: Duration,
259        ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> {
260            let MutexGuard { _mutex, inner } = guard;
261
262            map_lockresult(self.0.wait_timeout(inner, dur), |(inner, result)| {
263                (MutexGuard { _mutex, inner }, result)
264            })
265        }
266
267        /// Wrapper for [`std::sync::Condvar::wait_timeout_while`].
268        pub fn wait_timeout_while<'a, T, F>(
269            &self,
270            guard: MutexGuard<'a, T>,
271            dur: Duration,
272            condition: F,
273        ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)>
274        where
275            F: FnMut(&mut T) -> bool,
276        {
277            let MutexGuard { _mutex, inner } = guard;
278
279            map_lockresult(
280                self.0.wait_timeout_while(inner, dur, condition),
281                |(inner, result)| (MutexGuard { _mutex, inner }, result),
282            )
283        }
284
285        /// Wrapper for [`std::sync::Condvar::notify_one`].
286        pub fn notify_one(&self) {
287            self.0.notify_one();
288        }
289
290        /// Wrapper for [`std::sync::Condvar::notify_all`].
291        pub fn notify_all(&self) {
292            self.0.notify_all();
293        }
294    }
295
296    /// Wrapper for [`std::sync::RwLock`].
297    #[derive(Debug, Default)]
298    pub struct RwLock<T> {
299        inner: sync::RwLock<T>,
300        id: LazyMutexId,
301    }
302
303    /// Hybrid wrapper for both [`std::sync::RwLockReadGuard`] and [`std::sync::RwLockWriteGuard`].
304    ///
305    /// Please refer to [`RwLockReadGuard`] and [`RwLockWriteGuard`] for usable types.
306    #[derive(Debug)]
307    pub struct TracingRwLockGuard<'a, L> {
308        inner: L,
309        _mutex: BorrowedMutex<'a>,
310    }
311
312    /// Wrapper around [`std::sync::RwLockReadGuard`].
313    pub type RwLockReadGuard<'a, T> = TracingRwLockGuard<'a, sync::RwLockReadGuard<'a, T>>;
314    /// Wrapper around [`std::sync::RwLockWriteGuard`].
315    pub type RwLockWriteGuard<'a, T> = TracingRwLockGuard<'a, sync::RwLockWriteGuard<'a, T>>;
316
317    impl<T> RwLock<T> {
318        pub const fn new(t: T) -> Self {
319            Self {
320                inner: sync::RwLock::new(t),
321                id: LazyMutexId::new(),
322            }
323        }
324
325        /// Wrapper for [`std::sync::RwLock::read`].
326        ///
327        /// # Panics
328        ///
329        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
330        /// dependency cycle, this method will panic.
331        #[track_caller]
332        pub fn read(&self) -> LockResult<RwLockReadGuard<T>> {
333            let mutex = self.id.get_borrowed();
334            let result = self.inner.read();
335
336            map_lockresult(result, |inner| TracingRwLockGuard {
337                inner,
338                _mutex: mutex,
339            })
340        }
341
342        /// Wrapper for [`std::sync::RwLock::write`].
343        ///
344        /// # Panics
345        ///
346        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
347        /// dependency cycle, this method will panic.
348        #[track_caller]
349        pub fn write(&self) -> LockResult<RwLockWriteGuard<T>> {
350            let mutex = self.id.get_borrowed();
351            let result = self.inner.write();
352
353            map_lockresult(result, |inner| TracingRwLockGuard {
354                inner,
355                _mutex: mutex,
356            })
357        }
358
359        /// Wrapper for [`std::sync::RwLock::try_read`].
360        ///
361        /// # Panics
362        ///
363        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
364        /// dependency cycle, this method will panic.
365        #[track_caller]
366        pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<T>> {
367            let mutex = self.id.get_borrowed();
368            let result = self.inner.try_read();
369
370            map_trylockresult(result, |inner| TracingRwLockGuard {
371                inner,
372                _mutex: mutex,
373            })
374        }
375
376        /// Wrapper for [`std::sync::RwLock::try_write`].
377        ///
378        /// # Panics
379        ///
380        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
381        /// dependency cycle, this method will panic.
382        #[track_caller]
383        pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<T>> {
384            let mutex = self.id.get_borrowed();
385            let result = self.inner.try_write();
386
387            map_trylockresult(result, |inner| TracingRwLockGuard {
388                inner,
389                _mutex: mutex,
390            })
391        }
392
393        /// Return a mutable reference to the underlying data.
394        ///
395        /// This method does not block as the locking is handled compile-time by the type system.
396        pub fn get_mut(&mut self) -> LockResult<&mut T> {
397            self.inner.get_mut()
398        }
399
400        /// Unwrap the mutex and return its inner value.
401        pub fn into_inner(self) -> LockResult<T> {
402            self.inner.into_inner()
403        }
404    }
405
406    impl<T> From<T> for RwLock<T> {
407        fn from(t: T) -> Self {
408            Self::new(t)
409        }
410    }
411
412    impl<'a, L, T> Deref for TracingRwLockGuard<'a, L>
413    where
414        L: Deref<Target = T>,
415    {
416        type Target = T;
417
418        fn deref(&self) -> &Self::Target {
419            self.inner.deref()
420        }
421    }
422
423    impl<'a, T, L> DerefMut for TracingRwLockGuard<'a, L>
424    where
425        L: Deref<Target = T> + DerefMut,
426    {
427        fn deref_mut(&mut self) -> &mut Self::Target {
428            self.inner.deref_mut()
429        }
430    }
431
432    /// Wrapper around [`std::sync::Once`].
433    ///
434    /// Refer to the [crate-level][`crate`] documentaiton for the differences between this struct
435    /// and the one it wraps.
436    #[derive(Debug)]
437    pub struct Once {
438        inner: sync::Once,
439        mutex_id: LazyMutexId,
440    }
441
442    impl Once {
443        /// Create a new `Once` value.
444        pub const fn new() -> Self {
445            Self {
446                inner: sync::Once::new(),
447                mutex_id: LazyMutexId::new(),
448            }
449        }
450
451        /// Wrapper for [`std::sync::Once::call_once`].
452        ///
453        /// # Panics
454        ///
455        /// In addition to the panics that `Once` can cause, this method will panic if calling it
456        /// introduces a cycle in the lock dependency graph.
457        pub fn call_once<F>(&self, f: F)
458        where
459            F: FnOnce(),
460        {
461            let _guard = self.mutex_id.get_borrowed();
462            self.inner.call_once(f);
463        }
464
465        /// Performs the same operation as [`call_once`][Once::call_once] except it ignores
466        /// poisoning.
467        ///
468        /// # Panics
469        ///
470        /// This method participates in lock dependency tracking. If acquiring this lock introduces a
471        /// dependency cycle, this method will panic.
472        pub fn call_once_force<F>(&self, f: F)
473        where
474            F: FnOnce(&OnceState),
475        {
476            let _guard = self.mutex_id.get_borrowed();
477            self.inner.call_once_force(f);
478        }
479
480        /// Returns true if some `call_once` has completed successfully.
481        pub fn is_completed(&self) -> bool {
482            self.inner.is_completed()
483        }
484    }
485
486    /// Wrapper for [`std::sync::OnceLock`]
487    ///
488    /// The exact locking behaviour of [`std::sync::OnceLock`] is currently undefined, but may
489    /// deadlock in the event of reentrant initialization attempts. This wrapper participates in
490    /// cycle detection as normal and will therefore panic in the event of reentrancy.
491    ///
492    /// Most of this primitive's methods do not involve locking and as such are simply passed
493    /// through to the inner implementation.
494    ///
495    /// # Examples
496    ///
497    /// ```
498    /// use tracing_mutex::stdsync::tracing::OnceLock;
499    ///
500    /// static LOCK: OnceLock<i32> = OnceLock::new();
501    /// assert!(LOCK.get().is_none());
502    ///
503    /// std::thread::spawn(|| {
504    ///    let value: &i32 = LOCK.get_or_init(|| 42);
505    ///    assert_eq!(value, &42);
506    /// }).join().unwrap();
507    ///
508    /// let value: Option<&i32> = LOCK.get();
509    /// assert_eq!(value, Some(&42));
510    /// ```
511    #[derive(Debug)]
512    pub struct OnceLock<T> {
513        id: LazyMutexId,
514        inner: sync::OnceLock<T>,
515    }
516
517    // N.B. this impl inlines everything that directly calls the inner implementation as there
518    // should be 0 overhead to doing so.
519    impl<T> OnceLock<T> {
520        /// Creates a new empty cell
521        pub const fn new() -> Self {
522            Self {
523                id: LazyMutexId::new(),
524                inner: sync::OnceLock::new(),
525            }
526        }
527
528        /// Gets a reference to the underlying value.
529        ///
530        /// This method does not attempt to lock and therefore does not participate in cycle
531        /// detection.
532        #[inline]
533        pub fn get(&self) -> Option<&T> {
534            self.inner.get()
535        }
536
537        /// Gets a mutable reference to the underlying value.
538        ///
539        /// This method does not attempt to lock and therefore does not participate in cycle
540        /// detection.
541        #[inline]
542        pub fn get_mut(&mut self) -> Option<&mut T> {
543            self.inner.get_mut()
544        }
545
546        /// Sets the contents of this cell to the underlying value
547        ///
548        /// As this method may block until initialization is complete, it participates in cycle
549        /// detection.
550        pub fn set(&self, value: T) -> Result<(), T> {
551            let _guard = self.id.get_borrowed();
552
553            self.inner.set(value)
554        }
555
556        /// Gets the contents of the cell, initializing it with `f` if the cell was empty.
557        ///
558        /// This method participates in cycle detection. Reentrancy is considered a cycle.
559        pub fn get_or_init<F>(&self, f: F) -> &T
560        where
561            F: FnOnce() -> T,
562        {
563            let _guard = self.id.get_borrowed();
564            self.inner.get_or_init(f)
565        }
566
567        /// Takes the value out of this `OnceLock`, moving it back to an uninitialized state.
568        ///
569        /// This method does not attempt to lock and therefore does not participate in cycle
570        /// detection.
571        #[inline]
572        pub fn take(&mut self) -> Option<T> {
573            self.inner.take()
574        }
575
576        /// Consumes the `OnceLock`, returning the wrapped value. Returns None if the cell was
577        /// empty.
578        ///
579        /// This method does not attempt to lock and therefore does not participate in cycle
580        /// detection.
581        #[inline]
582        pub fn into_inner(mut self) -> Option<T> {
583            self.take()
584        }
585    }
586
587    impl<T> Default for OnceLock<T> {
588        #[inline]
589        fn default() -> Self {
590            Self::new()
591        }
592    }
593
594    impl<T: PartialEq> PartialEq for OnceLock<T> {
595        #[inline]
596        fn eq(&self, other: &Self) -> bool {
597            self.inner == other.inner
598        }
599    }
600
601    impl<T: Eq> Eq for OnceLock<T> {}
602
603    impl<T: Clone> Clone for OnceLock<T> {
604        fn clone(&self) -> Self {
605            Self {
606                id: LazyMutexId::new(),
607                inner: self.inner.clone(),
608            }
609        }
610    }
611
612    impl<T> From<T> for OnceLock<T> {
613        #[inline]
614        fn from(value: T) -> Self {
615            Self {
616                id: LazyMutexId::new(),
617                inner: sync::OnceLock::from(value),
618            }
619        }
620    }
621
622    #[cfg(test)]
623    mod tests {
624        use std::sync::Arc;
625        use std::thread;
626
627        use super::*;
628
629        #[test]
630        fn test_mutex_usage() {
631            let mutex = Arc::new(Mutex::new(0));
632
633            assert_eq!(*mutex.lock().unwrap(), 0);
634            *mutex.lock().unwrap() = 1;
635            assert_eq!(*mutex.lock().unwrap(), 1);
636
637            let mutex_clone = mutex.clone();
638
639            let _guard = mutex.lock().unwrap();
640
641            // Now try to cause a blocking exception in another thread
642            let handle = thread::spawn(move || {
643                let result = mutex_clone.try_lock().unwrap_err();
644
645                assert!(matches!(result, TryLockError::WouldBlock));
646            });
647
648            handle.join().unwrap();
649        }
650
651        #[test]
652        fn test_rwlock_usage() {
653            let rwlock = Arc::new(RwLock::new(0));
654
655            assert_eq!(*rwlock.read().unwrap(), 0);
656            assert_eq!(*rwlock.write().unwrap(), 0);
657            *rwlock.write().unwrap() = 1;
658            assert_eq!(*rwlock.read().unwrap(), 1);
659            assert_eq!(*rwlock.write().unwrap(), 1);
660
661            let rwlock_clone = rwlock.clone();
662
663            let _read_lock = rwlock.read().unwrap();
664
665            // Now try to cause a blocking exception in another thread
666            let handle = thread::spawn(move || {
667                let write_result = rwlock_clone.try_write().unwrap_err();
668
669                assert!(matches!(write_result, TryLockError::WouldBlock));
670
671                // Should be able to get a read lock just fine.
672                let _read_lock = rwlock_clone.read().unwrap();
673            });
674
675            handle.join().unwrap();
676        }
677
678        #[test]
679        fn test_once_usage() {
680            let once = Arc::new(Once::new());
681            let once_clone = once.clone();
682
683            assert!(!once.is_completed());
684
685            let handle = thread::spawn(move || {
686                assert!(!once_clone.is_completed());
687
688                once_clone.call_once(|| {});
689
690                assert!(once_clone.is_completed());
691            });
692
693            handle.join().unwrap();
694
695            assert!(once.is_completed());
696        }
697
698        #[test]
699        #[should_panic(expected = "Found cycle in mutex dependency graph")]
700        fn test_detect_cycle() {
701            let a = Mutex::new(());
702            let b = Mutex::new(());
703
704            let hold_a = a.lock().unwrap();
705            let _ = b.lock();
706
707            drop(hold_a);
708
709            let _hold_b = b.lock().unwrap();
710            let _ = a.lock();
711        }
712    }
713}