Skip to main content

ksync/
kmutex.rs

1// Copyright 2026 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::marker::PhantomData;
6use core::pin::Pin;
7use pin_init::{PinInit, pin_data, pin_init, pin_init_from_closure, pinned_drop};
8
9use crate::{LockToken, RawLock, RawMutex};
10use lockdep::LockClass;
11
12/// A safe, Zircon-compatible mutual exclusion lock supporting compile-time order validation.
13///
14/// `KMutex` wraps a platform-specific `RawLock` abstraction. It is pinned in memory to support FFI
15/// loop-detector active list registrations safely under the lock class `Class`.
16#[repr(transparent)] // Ensure KMutex has the same layout as the underlying RawLock M.
17#[pin_data]
18pub struct KMutex<Class: LockClass, M: RawLock = RawMutex> {
19    #[pin]
20    mutex: M,
21    _marker: PhantomData<Class>,
22}
23
24impl<Class: LockClass, M: RawLock> KMutex<Class, M> {
25    /// Create a new KMutex with a pre-initialized raw lock.
26    pub const fn new(mutex: M) -> Self {
27        Self { mutex, _marker: PhantomData }
28    }
29
30    /// Safe dynamic initialization of the validation lock inside pin context.
31    pub fn init() -> impl PinInit<Self, core::convert::Infallible> {
32        pin_init!(Self {
33            mutex <- M::init(),
34            _marker: PhantomData,
35        })
36    }
37
38    /// Acquires the lock and registers the active loop node.
39    #[inline]
40    pub fn lock(&self) -> impl PinInit<KMutexGuard<'_, Class, M>, core::convert::Infallible> {
41        KMutexGuard::new(self)
42    }
43}
44
45impl<Class: LockClass, M: RawLock> core::fmt::Debug for KMutex<Class, M> {
46    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
47        f.debug_struct("KMutex").field("class", &core::any::type_name::<Class>()).finish()
48    }
49}
50
51/// A validation guard representing exclusive lock ownership and active list participation.
52///
53/// The guard is pinned in memory to ensure that its `lock_entry` pointer remains safe and valid
54/// inside the C++ loop detector active thread list.
55#[repr(C)]
56#[pin_data(PinnedDrop)]
57pub struct KMutexGuard<'a, Class: LockClass, M: RawLock = RawMutex> {
58    mutex: &'a KMutex<Class, M>,
59
60    #[pin]
61    lock_entry: M::LockEntry,
62
63    state: M::GuardState,
64
65    token: LockToken<'a, Class>,
66}
67
68impl<'a, Class: LockClass, M: RawLock> KMutexGuard<'a, Class, M> {
69    /// Creates a new stack-pinned validation guard initialization block.
70    pub fn new(mutex: &'a KMutex<Class, M>) -> impl PinInit<Self, core::convert::Infallible> {
71        // SAFETY: The closure correctly initializes all fields of the allocated `KMutexGuard`
72        // and satisfies all safety requirements of `pin_init_from_closure`.
73        unsafe {
74            pin_init_from_closure(move |this: *mut Self| -> Result<(), core::convert::Infallible> {
75                // SAFETY: `this` is a valid pointer to uninitialized memory allocated for
76                // `KMutexGuard`.
77
78                let mutex_addr = core::ptr::addr_of_mut!((*this).mutex);
79                core::ptr::write(mutex_addr, mutex);
80
81                let entry_addr = core::ptr::addr_of_mut!((*this).lock_entry);
82                core::ptr::write(entry_addr, M::LockEntry::default());
83
84                let state = mutex.mutex.acquire(Class::ID, entry_addr);
85
86                let state_addr = core::ptr::addr_of_mut!((*this).state);
87                core::ptr::write(state_addr, state);
88
89                let token_addr = core::ptr::addr_of_mut!((*this).token);
90                core::ptr::write(token_addr, LockToken::new());
91
92                Ok(())
93            })
94        }
95    }
96
97    /// Returns a shared reference to the lock proof `LockToken`.
98    #[inline]
99    pub fn token(&self) -> &LockToken<'a, Class> {
100        &self.token
101    }
102
103    /// Returns a mutable reference to the lock proof `LockToken` inside this pinned projection.
104    #[inline]
105    pub fn token_mut(self: Pin<&mut Self>) -> &mut LockToken<'a, Class> {
106        // SAFETY: Modifying the non-pinned raw `token` field does not violate pinning invariants
107        // since the token has no drop logic or pointer-location sensitivity.
108        let me = unsafe { self.get_unchecked_mut() };
109        &mut me.token
110    }
111}
112
113#[pinned_drop]
114impl<'a, Class: LockClass, M: RawLock> PinnedDrop for KMutexGuard<'a, Class, M> {
115    // SAFETY: The stack slot `lock_entry` remains valid and pinned on the stack until this drop
116    // block completes. Accessing the fields directly to release the raw lock and remove the
117    // active list node is safe and correct under the current thread context.
118    fn drop(self: Pin<&mut Self>) {
119        unsafe {
120            let me = self.get_unchecked_mut();
121            let entry_addr = &mut me.lock_entry as *mut _;
122            me.mutex.mutex.release(entry_addr, me.state);
123        }
124    }
125}
126
#[cfg(not(feature = "kernel"))]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::guarded;
    use crate::{KCell, RawMutex};
    use lockdep::LockClass;
    use pin_init::{pin_init, stack_pin_init};

    struct MyClass;
    impl LockClass for MyClass {
        const ID: *mut core::ffi::c_void = core::ptr::null_mut();
    }

    #[pin_init::pin_data]
    struct MyStruct {
        #[pin]
        mu: KMutex<MyClass>,
        data1: KCell<u32, MyClass>,
        data2: KCell<i32, MyClass>,
    }

    #[test]
    fn test_basic_token_access() {
        stack_pin_init!(let s = pin_init!(MyStruct {
            mu <- KMutex::init(),
            data1: KCell::new(10),
            data2: KCell::new(-5),
        }));

        lock!(let mut guard = s.mu.lock());

        // SAFETY (for every block below): `guard` holds `s.mu`, the mutex guarding `data1` and
        // `data2`, for the remainder of this test.
        unsafe {
            assert_eq!(*s.data1.get(guard.token()), 10);
            assert_eq!(*s.data2.get(guard.token()), -5);
        }
        unsafe {
            let token_mut = guard.as_mut().token_mut();
            *s.data1.get_mut(token_mut) = 20;
            assert_eq!(*s.data1.get(guard.token()), 20);
        }
        unsafe {
            let token_mut = guard.as_mut().token_mut();
            let p1 = s.data1.as_mut_ptr(token_mut);
            let p2 = s.data2.as_mut_ptr(token_mut);
            *p1 = 30;
            *p2 = -10;
        }
        unsafe {
            assert_eq!(*s.data1.get(guard.token()), 30);
            assert_eq!(*s.data2.get(guard.token()), -10);
        }
    }

    #[test]
    fn test_kmutex_init() {
        // Smoke test: a mutex built with `KMutex::init` can be locked.
        stack_pin_init!(let mu = KMutex::<MyClass>::init());
        lock!(mu.lock());
    }

    #[test]
    fn test_kmutex_debug() {
        extern crate std;
        stack_pin_init!(let mu = KMutex::<MyClass>::init());
        let debug_str = std::format!("{:?}", mu);
        assert!(debug_str.contains("KMutex"));
        assert!(debug_str.contains("class"));
    }

    #[guarded]
    struct MyGuardedStruct {
        #[mutex]
        mu: KMutex,
        #[guarded_by(mu)]
        data1: u32,
        #[guarded_by(mu)]
        data2: i32,
    }

    #[test]
    fn test_macro_guarded() {
        stack_pin_init!(let s = pin_init!(MyGuardedStruct {
            mu <- KMutex::init(),
            data1: 100.into(),
            data2: (-50).into(),
        }));

        lock!(let mut guard = s.lock_mu());

        assert_eq!(*guard.data1(), 100);
        assert_eq!(*guard.data2(), -50);

        *guard.as_mut().data1_mut() = 200;
        assert_eq!(*guard.data1(), 200);
        {
            let fields = guard.fields();
            assert_eq!(*fields.data1, 200);
            assert_eq!(*fields.data2, -50);
        }
        {
            let fields = guard.as_mut().fields_mut();
            *fields.data1 = 300;
            *fields.data2 = -100;
        }
        assert_eq!(*guard.data1(), 300);
        assert_eq!(*guard.data2(), -100);
    }

    #[guarded]
    struct MyMultiGuardedStruct {
        #[mutex]
        mu1: KMutex,
        #[mutex]
        mu2: KMutex,
        #[guarded_by(mu1)]
        data1: u32,
        #[guarded_by(mu2)]
        data2: i32,
    }

    #[test]
    fn test_macro_multi_guarded() {
        stack_pin_init!(let s = pin_init!(MyMultiGuardedStruct {
            mu1 <- KMutex::init(),
            mu2 <- KMutex::init(),
            data1: 10.into(),
            data2: 20.into(),
        }));

        lock!(let mut guard1 = s.lock_mu1());
        lock!(let mut guard2 = s.lock_mu2());

        assert_eq!(*guard1.data1(), 10);
        assert_eq!(*guard2.data2(), 20);
        *guard1.as_mut().data1_mut() = 15;
        *guard2.as_mut().data2_mut() = 25;
        assert_eq!(*guard1.data1(), 15);
        assert_eq!(*guard2.data2(), 25);
    }

    #[guarded]
    struct MyDefaultGuardedStruct {
        #[mutex]
        mu: KMutex,
        #[guarded_by(mu)]
        data: u32,
    }

    #[test]
    fn test_derive_default_guarded() {
        stack_pin_init!(let s = pin_init!(MyDefaultGuardedStruct {
            mu <- KMutex::init(),
            data: 0.into(),
        }));
        lock!(let guard = s.lock_mu());
        assert_eq!(*guard.data(), 0);
    }

    #[guarded]
    struct MyGenericLockGuardedStruct<L: RawLock> {
        #[mutex]
        mu: KMutex<L>,
        #[guarded_by(mu)]
        data: u32,
    }

    #[test]
    fn test_macro_generic_lock_guarded() {
        stack_pin_init!(let s = pin_init!(MyGenericLockGuardedStruct::<RawMutex> {
            mu <- KMutex::init(),
            data: 100.into(),
        }));

        lock!(let guard = s.lock_mu());
        assert_eq!(*guard.data(), 100);
    }

    #[guarded]
    struct MyGenericGuardedStruct<T> {
        #[mutex]
        mu: KMutex,
        #[guarded_by(mu)]
        data: T,
    }

    #[test]
    fn test_macro_generic_guarded() {
        stack_pin_init!(let s = pin_init!(MyGenericGuardedStruct::<u32> {
            mu <- KMutex::init(),
            data: 0.into(),
        }));
        lock!(let mut guard = s.lock_mu());
        assert_eq!(*guard.data(), 0);

        *guard.as_mut().data_mut() = 42;
        assert_eq!(*guard.data(), 42);

        let fields = guard.as_mut().fields_mut();
        *fields.data = 100;

        let fields_shared = guard.fields();
        assert_eq!(*fields_shared.data, 100);
    }

    #[guarded]
    struct MyExplicitParentGuardedStruct {
        #[mutex]
        mu: KMutex,
        #[guarded_by(mu)]
        data: u32,
        pub label: &'static str,
    }

    impl MyExplicitParentGuardedStruct {
        pub fn has_label(&self) -> bool {
            !self.label.is_empty()
        }
    }

    impl<'a> MyExplicitParentGuardedStructMuGuard<'a> {
        pub fn process_with_context(self: Pin<&mut Self>) {
            // Shared reads of the parent go through `Pin`'s `Deref`. No unpinning (and thus no
            // `unsafe`) is needed.
            let has_label = self.parent.has_label();
            let label = self.parent.label;
            if has_label && label == "apply_update" {
                let fields = self.fields_mut();
                *fields.data = 100;
            }
        }
    }

    #[test]
    fn test_macro_guard_explicit_parent_access() {
        stack_pin_init!(let s = pin_init!(MyExplicitParentGuardedStruct {
            mu <- KMutex::init(),
            data: 0.into(),
            label: "apply_update",
        }));

        {
            lock!(let mut guard = s.lock_mu());
            guard.as_mut().process_with_context();
        }

        lock!(let guard = s.lock_mu());
        assert_eq!(*guard.data(), 100);
    }
}