1use core::marker::PhantomData;
6use core::pin::Pin;
7use pin_init::{PinInit, pin_data, pin_init, pin_init_from_closure, pinned_drop};
8
9use crate::{LockToken, RawLock, RawMutex};
10use lockdep::LockClass;
11
12#[repr(transparent)] #[pin_data]
18pub struct KMutex<Class: LockClass, M: RawLock = RawMutex> {
19 #[pin]
20 mutex: M,
21 _marker: PhantomData<Class>,
22}
23
24impl<Class: LockClass, M: RawLock> KMutex<Class, M> {
25 pub const fn new(mutex: M) -> Self {
27 Self { mutex, _marker: PhantomData }
28 }
29
30 pub fn init() -> impl PinInit<Self, core::convert::Infallible> {
32 pin_init!(Self {
33 mutex <- M::init(),
34 _marker: PhantomData,
35 })
36 }
37
38 #[inline]
40 pub fn lock(&self) -> impl PinInit<KMutexGuard<'_, Class, M>, core::convert::Infallible> {
41 KMutexGuard::new(self)
42 }
43}
44
45impl<Class: LockClass, M: RawLock> core::fmt::Debug for KMutex<Class, M> {
46 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
47 f.debug_struct("KMutex").field("class", &core::any::type_name::<Class>()).finish()
48 }
49}
50
51#[repr(C)]
56#[pin_data(PinnedDrop)]
57pub struct KMutexGuard<'a, Class: LockClass, M: RawLock = RawMutex> {
58 mutex: &'a KMutex<Class, M>,
59
60 #[pin]
61 lock_entry: M::LockEntry,
62
63 state: M::GuardState,
64
65 token: LockToken<'a, Class>,
66}
67
68impl<'a, Class: LockClass, M: RawLock> KMutexGuard<'a, Class, M> {
69 pub fn new(mutex: &'a KMutex<Class, M>) -> impl PinInit<Self, core::convert::Infallible> {
71 unsafe {
74 pin_init_from_closure(move |this: *mut Self| -> Result<(), core::convert::Infallible> {
75 let mutex_addr = core::ptr::addr_of_mut!((*this).mutex);
79 core::ptr::write(mutex_addr, mutex);
80
81 let entry_addr = core::ptr::addr_of_mut!((*this).lock_entry);
82 core::ptr::write(entry_addr, M::LockEntry::default());
83
84 let state = mutex.mutex.acquire(Class::ID, entry_addr);
85
86 let state_addr = core::ptr::addr_of_mut!((*this).state);
87 core::ptr::write(state_addr, state);
88
89 let token_addr = core::ptr::addr_of_mut!((*this).token);
90 core::ptr::write(token_addr, LockToken::new());
91
92 Ok(())
93 })
94 }
95 }
96
97 #[inline]
99 pub fn token(&self) -> &LockToken<'a, Class> {
100 &self.token
101 }
102
103 #[inline]
105 pub fn token_mut(self: Pin<&mut Self>) -> &mut LockToken<'a, Class> {
106 let me = unsafe { self.get_unchecked_mut() };
109 &mut me.token
110 }
111}
112
113#[pinned_drop]
114impl<'a, Class: LockClass, M: RawLock> PinnedDrop for KMutexGuard<'a, Class, M> {
115 fn drop(self: Pin<&mut Self>) {
119 unsafe {
120 let me = self.get_unchecked_mut();
121 let entry_addr = &mut me.lock_entry as *mut _;
122 me.mutex.mutex.release(entry_addr, me.state);
123 }
124 }
125}
126
127#[cfg(not(feature = "kernel"))]
128#[cfg(test)]
129mod tests {
130 use super::*;
131 use crate::{KCell, RawMutex};
132 use lockdep::LockClass;
133 use pin_init::{pin_init, stack_pin_init};
134
135 struct MyClass;
136 impl LockClass for MyClass {
137 const ID: *mut core::ffi::c_void = core::ptr::null_mut();
138 }
139
140 #[pin_init::pin_data]
141 struct MyStruct {
142 #[pin]
143 mu: KMutex<MyClass>,
144 data1: KCell<u32, MyClass>,
145 data2: KCell<i32, MyClass>,
146 }
147
148 #[test]
149 fn test_basic_token_access() {
150 stack_pin_init!(let s = pin_init!(MyStruct {
151 mu <- KMutex::init(),
152 data1: KCell::new(10),
153 data2: KCell::new(-5),
154 }));
155
156 lock!(let mut guard = s.mu.lock());
157
158 unsafe {
159 assert_eq!(*s.data1.get(guard.token()), 10);
160 assert_eq!(*s.data2.get(guard.token()), -5);
161 }
162 unsafe {
163 let token_mut = guard.as_mut().token_mut();
164 *s.data1.get_mut(token_mut) = 20;
165 assert_eq!(*s.data1.get(guard.token()), 20);
166 }
167 unsafe {
168 let token_mut = guard.as_mut().token_mut();
169 let p1 = s.data1.as_mut_ptr(token_mut);
170 let p2 = s.data2.as_mut_ptr(token_mut);
171 *p1 = 30;
172 *p2 = -10;
173 }
174 unsafe {
175 assert_eq!(*s.data1.get(guard.token()), 30);
176 assert_eq!(*s.data2.get(guard.token()), -10);
177 }
178 }
179
180 #[test]
181 fn test_kmutex_init() {
182 stack_pin_init!(let mu = KMutex::<MyClass>::init());
183 lock!(mu.lock());
184 }
185
186 #[test]
187 fn test_kmutex_debug() {
188 extern crate std;
189 stack_pin_init!(let mu = KMutex::<MyClass>::init());
190 let debug_str = std::format!("{:?}", mu);
191 assert!(debug_str.contains("KMutex"));
192 }
193
194 use crate::guarded;
195
196 #[guarded]
197 struct MyGuardedStruct {
198 #[mutex]
199 mu: KMutex,
200 #[guarded_by(mu)]
201 data1: u32,
202 #[guarded_by(mu)]
203 data2: i32,
204 }
205
206 #[test]
207 fn test_macro_guarded() {
208 stack_pin_init!(let s = pin_init!(MyGuardedStruct {
209 mu <- KMutex::init(),
210 data1: 100.into(),
211 data2: (-50).into(),
212 }));
213
214 lock!(let mut guard = s.lock_mu());
215
216 assert_eq!(*guard.data1(), 100);
217 assert_eq!(*guard.data2(), -50);
218
219 *guard.as_mut().data1_mut() = 200;
220 assert_eq!(*guard.data1(), 200);
221 {
222 let fields = guard.fields();
223 assert_eq!(*fields.data1, 200);
224 assert_eq!(*fields.data2, -50);
225 }
226 {
227 let fields = guard.as_mut().fields_mut();
228 *fields.data1 = 300;
229 *fields.data2 = -100;
230 }
231 assert_eq!(*guard.data1(), 300);
232 assert_eq!(*guard.data2(), -100);
233 }
234
235 #[guarded]
236 struct MyMultiGuardedStruct {
237 #[mutex]
238 mu1: KMutex,
239 #[mutex]
240 mu2: KMutex,
241 #[guarded_by(mu1)]
242 data1: u32,
243 #[guarded_by(mu2)]
244 data2: i32,
245 }
246
247 #[test]
248 fn test_macro_multi_guarded() {
249 stack_pin_init!(let s = pin_init!(MyMultiGuardedStruct {
250 mu1 <- KMutex::init(),
251 mu2 <- KMutex::init(),
252 data1: 10.into(),
253 data2: 20.into(),
254 }));
255
256 lock!(let mut guard1 = s.lock_mu1());
257 lock!(let mut guard2 = s.lock_mu2());
258
259 assert_eq!(*guard1.data1(), 10);
260 assert_eq!(*guard2.data2(), 20);
261 *guard1.as_mut().data1_mut() = 15;
262 *guard2.as_mut().data2_mut() = 25;
263 assert_eq!(*guard1.data1(), 15);
264 assert_eq!(*guard2.data2(), 25);
265 }
266
267 #[guarded]
268 struct MyDefaultGuardedStruct {
269 #[mutex]
270 mu: KMutex,
271 #[guarded_by(mu)]
272 data: u32,
273 }
274
275 #[test]
276 fn test_derive_default_guarded() {
277 stack_pin_init!(let s = pin_init!(MyDefaultGuardedStruct {
278 mu <- KMutex::init(),
279 data: 0.into(),
280 }));
281 lock!(let guard = s.lock_mu());
282 assert_eq!(*guard.data(), 0);
283 }
284
285 #[guarded]
286 struct MyGenericLockGuardedStruct<L: RawLock> {
287 #[mutex]
288 mu: KMutex<L>,
289 #[guarded_by(mu)]
290 data: u32,
291 }
292
293 #[test]
294 fn test_macro_generic_lock_guarded() {
295 stack_pin_init!(let s = pin_init!(MyGenericLockGuardedStruct::<RawMutex> {
296 mu <- KMutex::init(),
297 data: 100.into(),
298 }));
299
300 lock!(let guard = s.lock_mu());
301 assert_eq!(*guard.data(), 100);
302 }
303
304 #[guarded]
305 struct MyGenericGuardedStruct<T> {
306 #[mutex]
307 mu: KMutex,
308 #[guarded_by(mu)]
309 data: T,
310 }
311
312 #[test]
313 fn test_macro_generic_guarded() {
314 stack_pin_init!(let s = pin_init!(MyGenericGuardedStruct::<u32> {
315 mu <- KMutex::init(),
316 data: 0.into(),
317 }));
318 lock!(let mut guard = s.lock_mu());
319 assert_eq!(*guard.data(), 0);
320
321 *guard.as_mut().data_mut() = 42;
322 assert_eq!(*guard.data(), 42);
323
324 let fields = guard.as_mut().fields_mut();
325 *fields.data = 100;
326
327 let fields_shared = guard.fields();
328 assert_eq!(*fields_shared.data, 100);
329 }
330
331 #[guarded]
332 struct MyExplicitParentGuardedStruct {
333 #[mutex]
334 mu: KMutex,
335 #[guarded_by(mu)]
336 data: u32,
337 pub label: &'static str,
338 }
339
340 impl MyExplicitParentGuardedStruct {
341 pub fn has_label(&self) -> bool {
342 !self.label.is_empty()
343 }
344 }
345
346 impl<'a> MyExplicitParentGuardedStructMuGuard<'a> {
347 pub fn process_with_context(self: Pin<&mut Self>) {
348 let me = unsafe { self.get_unchecked_mut() };
349 let has_label = me.parent.has_label();
350 let label = me.parent.label;
351 if has_label && label == "apply_update" {
352 unsafe {
353 let mut_self = Pin::new_unchecked(me);
354 let fields = mut_self.fields_mut();
355 *fields.data = 100;
356 }
357 }
358 }
359 }
360
361 #[test]
362 fn test_macro_guard_explicit_parent_access() {
363 stack_pin_init!(let s = pin_init!(MyExplicitParentGuardedStruct {
364 mu <- KMutex::init(),
365 data: 0.into(),
366 label: "apply_update",
367 }));
368
369 {
370 lock!(let mut guard = s.lock_mu());
371 guard.as_mut().process_with_context();
372 }
373
374 lock!(let guard = s.lock_mu());
375 assert_eq!(*guard.data(), 100);
376 }
377}