1use crate::task::{CurrentTask, WaitQueue, Waiter};
6use crate::vfs::{FdTableId, FileObject, FileObjectId};
7use starnix_sync::{Locked, Mutex, Unlocked};
8use starnix_uapi::errors::{EAGAIN, Errno};
9use starnix_uapi::{
10 __kernel_off_t, F_GETLK, F_GETLK64, F_OFD_GETLK, F_OFD_SETLK, F_OFD_SETLKW, F_RDLCK, F_SETLK,
11 F_SETLK64, F_SETLKW, F_SETLKW64, F_UNLCK, F_WRLCK, SEEK_CUR, SEEK_END, SEEK_SET, c_short,
12 errno, error, pid_t, uapi,
13};
14use std::collections::BTreeSet;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17enum RecordLength {
18 Value(usize),
19 Infinite,
20}
21
22impl RecordLength {
23 fn new(value: usize) -> Self {
24 if value == 0 { Self::Infinite } else { Self::Value(value) }
25 }
26 fn value(&self) -> __kernel_off_t {
27 match self {
28 Self::Value(e) => *e as __kernel_off_t,
29 Self::Infinite => 0,
30 }
31 }
32}
33
34impl std::ops::Add<usize> for RecordLength {
35 type Output = Self;
36
37 fn add(self, element: usize) -> Self {
38 match self {
39 Self::Value(e) => Self::Value(e.saturating_add(element)),
40 Self::Infinite => Self::Infinite,
41 }
42 }
43}
44
45impl std::ops::Sub<usize> for RecordLength {
46 type Output = Option<Self>;
47
48 fn sub(self, element: usize) -> Option<Self> {
49 match self {
50 Self::Value(e) if e > element => Some(Self::Value(e - element)),
51 Self::Infinite => Some(Self::Infinite),
52 _ => None,
53 }
54 }
55}
56
57impl std::cmp::PartialEq<usize> for RecordLength {
58 fn eq(&self, other: &usize) -> bool {
59 match self {
60 Self::Value(e) => e == other,
61 Self::Infinite => false,
62 }
63 }
64}
65
66impl std::cmp::PartialOrd<usize> for RecordLength {
67 fn partial_cmp(&self, other: &usize) -> Option<std::cmp::Ordering> {
68 match self {
69 Self::Value(e) => e.partial_cmp(other),
70 Self::Infinite => Some(std::cmp::Ordering::Greater),
71 }
72 }
73}
74
75impl Ord for RecordLength {
76 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
77 if self == other {
78 std::cmp::Ordering::Equal
79 } else {
80 match self {
81 Self::Value(e1) => match other {
82 Self::Value(e2) => e1.cmp(e2),
83 Self::Infinite => std::cmp::Ordering::Less,
84 },
85 Self::Infinite => std::cmp::Ordering::Greater,
86 }
87 }
88 }
89}
90
91impl PartialOrd for RecordLength {
92 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
93 Some(self.cmp(other))
94 }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
98struct RecordRange {
99 start: usize,
100 length: RecordLength,
101}
102
103impl RecordRange {
104 fn new(start: usize, length: usize) -> Self {
105 Self { start, length: RecordLength::new(length) }
106 }
107
108 fn build(flock: &uapi::flock, file: &FileObject) -> Result<RecordRange, Errno> {
112 let origin: __kernel_off_t = match flock.l_whence as u32 {
113 SEEK_SET => 0,
114 SEEK_CUR => *file.offset.lock(),
115 SEEK_END => file.node().info().size.try_into().map_err(|_| errno!(EINVAL))?,
116 _ => {
117 return error!(EINVAL);
118 }
119 };
120 let mut start = origin.checked_add(flock.l_start).ok_or_else(|| errno!(EOVERFLOW))?;
121 let mut length = flock.l_len;
122 if length < 0 {
123 start = start.checked_add(length).ok_or_else(|| errno!(EINVAL))?;
124 length = length.checked_neg().ok_or_else(|| errno!(EINVAL))?;
125 }
126 if start < 0 {
127 return error!(EINVAL);
128 }
129 Ok(Self::new(start as usize, length as usize))
130 }
131
132 fn end(&self) -> RecordLength {
133 self.length + self.start
134 }
135
136 fn intersects(&self, other: &RecordRange) -> bool {
137 let r1 = std::cmp::min(self, other);
138 let r2 = std::cmp::max(self, other);
139 r1.end() > r2.start
140 }
141}
142
143impl std::ops::Sub<RecordRange> for RecordRange {
144 type Output = Vec<Self>;
145
146 fn sub(self, other: RecordRange) -> Vec<RecordRange> {
147 if !self.intersects(&other) {
148 return vec![self];
149 }
150 let mut vec = Vec::with_capacity(2);
151 if self.start < other.start {
152 let length = std::cmp::min(RecordLength::Value(other.start - self.start), self.length);
153 vec.push(RecordRange { start: self.start, length });
154 }
155 if let RecordLength::Value(start) = other.end() {
156 let end = self.end();
157 if let Some(length) = end - start {
158 vec.push(RecordRange { start, length });
159 }
160 }
161 vec
162 }
163}
164
165impl std::ops::Add<RecordRange> for RecordRange {
166 type Output = Vec<Self>;
167
168 fn add(self, other: RecordRange) -> Vec<RecordRange> {
169 let r1 = std::cmp::min(self, other);
170 let r2 = std::cmp::max(self, other);
171 let r1_end = r1.end();
172 if r1_end < r2.start {
173 vec![r1, r2]
174 } else {
175 let end = std::cmp::max(r1_end, r2.end());
176 vec![RecordRange {
177 start: r1.start,
178 length: (end - r1.start).expect("Length is guaranteed to exist"),
179 }]
180 }
181 }
182}
183
184#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
185enum RecordLockType {
186 Read,
187 Write,
188}
189
190impl RecordLockType {
191 fn build(flock: &uapi::flock) -> Result<Option<Self>, Errno> {
192 match flock.l_type as u32 {
193 F_UNLCK => Ok(None),
194 F_RDLCK => Ok(Some(Self::Read)),
195 F_WRLCK => Ok(Some(Self::Write)),
196 _ => error!(EINVAL),
197 }
198 }
199
200 fn is_compatible(&self, other: RecordLockType) -> bool {
203 *self == Self::Read && other == Self::Read
204 }
205
206 fn has_permission(&self, file: &FileObject) -> bool {
207 match self {
208 Self::Read => file.can_read(),
209 Self::Write => file.can_write(),
210 }
211 }
212
213 fn value(&self) -> c_short {
214 match self {
215 Self::Read => F_RDLCK as c_short,
216 Self::Write => F_WRLCK as c_short,
217 }
218 }
219}
220
221#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
222pub enum RecordLockOwner {
223 FdTable(FdTableId),
224 FileObject(FileObjectId),
225}
226
227impl RecordLockOwner {
228 fn new(current_task: &CurrentTask, cmd: RecordLockCommand, file: &FileObject) -> Self {
229 if cmd.is_ofd() {
230 RecordLockOwner::FileObject(file.id)
231 } else {
232 RecordLockOwner::FdTable(current_task.files.id())
233 }
234 }
235}
236
237#[derive(Debug, Clone)]
238struct RecordLock {
239 pub owner: RecordLockOwner,
240 pub range: RecordRange,
241 pub lock_type: RecordLockType,
242 pub process_id: pid_t,
243}
244
245impl RecordLock {
246 fn as_tuple(&self) -> (RecordLockOwner, &RecordRange, RecordLockType) {
247 (self.owner, &self.range, self.lock_type)
248 }
249}
250
251impl PartialEq for RecordLock {
252 fn eq(&self, other: &Self) -> bool {
253 self.as_tuple() == other.as_tuple()
254 }
255}
256
257impl Eq for RecordLock {}
258
259impl Ord for RecordLock {
260 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
261 self.as_tuple().cmp(&other.as_tuple())
262 }
263}
264
265impl PartialOrd for RecordLock {
266 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
267 Some(self.cmp(other))
268 }
269}
270#[derive(Default, Debug)]
271pub struct RecordLocksState {
272 locks: BTreeSet<RecordLock>,
273 queue: WaitQueue,
274}
275
276impl RecordLocksState {
277 fn get_conflicting_lock(
280 &self,
281 owner: RecordLockOwner,
282 lock_type: RecordLockType,
283 range: &RecordRange,
284 ) -> Option<uapi::flock> {
285 for record in &self.locks {
286 if owner == record.owner {
287 continue;
288 }
289 if lock_type.is_compatible(record.lock_type) {
290 continue;
291 }
292 if range.intersects(&record.range) {
293 return Some(uapi::flock {
294 l_type: record.lock_type.value(),
295 l_whence: SEEK_SET as c_short,
296 l_start: record.range.start as __kernel_off_t,
297 l_len: record.range.length.value(),
298 l_pid: record.process_id,
299 ..Default::default()
300 });
301 }
302 }
303 None
304 }
305
306 fn apply_lock(
307 &mut self,
308 process_id: pid_t,
309 owner: RecordLockOwner,
310 lock_type: RecordLockType,
311 range: RecordRange,
312 ) -> Result<(), Errno> {
313 let mut owned_locks_in_range = Vec::new();
314 for lock in self.locks.iter().filter(|record| range.intersects(&record.range)) {
315 if owner == lock.owner {
316 owned_locks_in_range.push(lock.clone());
317 } else if !lock_type.is_compatible(lock.lock_type) {
318 return error!(EAGAIN);
320 }
321 }
322 let mut new_lock = RecordLock { owner, range, lock_type, process_id };
323 for lock in owned_locks_in_range {
324 self.locks.remove(&lock);
325 if lock.lock_type == lock_type {
326 let new_ranges = new_lock.range + lock.range;
327 assert!(new_ranges.len() == 1);
328 new_lock.range = new_ranges[0];
329 } else {
330 for range in lock.range - new_lock.range {
331 let mut remaining_lock = lock.clone();
332 remaining_lock.range = range;
333 self.locks.insert(remaining_lock);
334 }
335 }
336 }
337 self.locks.insert(new_lock);
338 self.queue.notify_all();
339 Ok(())
340 }
341
342 fn unlock(&mut self, owner: RecordLockOwner, range: RecordRange) -> Result<(), Errno> {
343 let intersection_locks: Vec<_> = self
344 .locks
345 .iter()
346 .filter(|record| owner == record.owner && range.intersects(&record.range))
347 .cloned()
348 .collect();
349 for lock in intersection_locks {
350 self.locks.remove(&lock);
351 for new_range in lock.range - range {
352 let mut new_lock = lock.clone();
353 new_lock.range = new_range;
354 self.locks.insert(new_lock);
355 }
356 }
357 self.queue.notify_all();
358 Ok(())
359 }
360
361 fn release_locks(&mut self, owner: RecordLockOwner) {
362 self.locks.retain(|lock| lock.owner != owner);
363 self.queue.notify_all();
364 }
365}
366
367#[allow(non_camel_case_types)]
368#[allow(clippy::upper_case_acronyms)]
369#[derive(Debug, Clone, Copy, PartialEq, Eq)]
370pub enum RecordLockCommand {
371 SETLK,
372 SETLKW,
373 GETLK,
374 OFD_GETLK,
375 OFD_SETLK,
376 OFD_SETLKW,
377}
378
379impl RecordLockCommand {
380 pub fn from_raw(cmd: u32) -> Option<Self> {
381 match cmd {
382 F_SETLK | F_SETLK64 => Some(RecordLockCommand::SETLK),
383 F_SETLKW | F_SETLKW64 => Some(RecordLockCommand::SETLKW),
384 F_GETLK | F_GETLK64 => Some(RecordLockCommand::GETLK),
385 F_OFD_GETLK => Some(RecordLockCommand::OFD_GETLK),
386 F_OFD_SETLK => Some(RecordLockCommand::OFD_SETLK),
387 F_OFD_SETLKW => Some(RecordLockCommand::OFD_SETLKW),
388 _ => None,
389 }
390 }
391
392 fn is_ofd(&self) -> bool {
393 match self {
394 RecordLockCommand::SETLK | RecordLockCommand::SETLKW | RecordLockCommand::GETLK => {
395 false
396 }
397 RecordLockCommand::OFD_GETLK
398 | RecordLockCommand::OFD_SETLK
399 | RecordLockCommand::OFD_SETLKW => true,
400 }
401 }
402
403 fn is_get(&self) -> bool {
404 *self == RecordLockCommand::GETLK || *self == RecordLockCommand::OFD_GETLK
405 }
406
407 fn is_blocking(&self) -> bool {
408 *self == RecordLockCommand::SETLKW || *self == RecordLockCommand::OFD_SETLKW
409 }
410}
411
412#[derive(Default, Debug)]
413pub struct RecordLocks {
414 state: Mutex<RecordLocksState>,
415}
416
417impl RecordLocks {
418 pub fn lock(
423 &self,
424 locked: &mut Locked<Unlocked>,
425 current_task: &CurrentTask,
426 file: &FileObject,
427 cmd: RecordLockCommand,
428 mut flock: uapi::flock,
429 ) -> Result<Option<uapi::flock>, Errno> {
430 if cmd.is_ofd() && flock.l_pid != 0 {
431 return error!(EINVAL);
432 }
433 let owner: RecordLockOwner = RecordLockOwner::new(current_task, cmd, file);
434 let lock_type = RecordLockType::build(&flock)?;
435 let range = RecordRange::build(&flock, file)?;
436 if cmd.is_get() {
437 let lock_type = lock_type.ok_or_else(|| errno!(EINVAL))?;
438 Ok(self.state.lock().get_conflicting_lock(owner, lock_type, &range).or_else(|| {
439 flock.l_type = F_UNLCK as c_short;
440 Some(flock)
441 }))
442 } else {
443 match lock_type {
444 Some(lock_type) => {
445 if !lock_type.has_permission(file) {
446 return error!(EBADF);
447 }
448 let blocking = cmd.is_blocking();
449 loop {
450 let mut state = self.state.lock();
451 let waiter = blocking.then(|| {
452 let waiter = Waiter::new();
453 state.queue.wait_async(&waiter);
454 waiter
455 });
456 let process_id =
457 if cmd.is_ofd() { -1 } else { current_task.thread_group().leader };
458 match state.apply_lock(process_id, owner, lock_type, range) {
459 Err(errno) if blocking && errno == EAGAIN => {
460 if let Some(waiter) = waiter {
462 std::mem::drop(state);
463 waiter.wait(locked, current_task)?;
464 }
465 }
466 result => return result.map(|_| None),
467 }
468 }
469 }
470 None => {
471 self.state.lock().unlock(owner, range)?;
472 }
473 }
474 Ok(None)
475 }
476 }
477
478 pub fn release_locks(&self, owner: RecordLockOwner) {
479 self.state.lock().release_locks(owner);
480 }
481}
482
#[cfg(test)]
mod test {
    use super::*;

    /// `RecordLength` arithmetic and ordering, including the `Infinite` (l_len == 0) case.
    #[::fuchsia::test]
    fn test_length_arithmetic_and_ordering() {
        assert_eq!(RecordLength::new(0), RecordLength::Infinite);
        assert_eq!(RecordLength::new(4), RecordLength::Value(4));
        assert_eq!(RecordLength::Infinite.value(), 0);
        assert_eq!(RecordLength::Value(4).value(), 4);
        assert_eq!(RecordLength::Value(4) + 3usize, RecordLength::Value(7));
        assert_eq!(RecordLength::Value(usize::MAX) + 3usize, RecordLength::Value(usize::MAX));
        assert_eq!(RecordLength::Infinite + 3usize, RecordLength::Infinite);
        assert_eq!(RecordLength::Value(4) - 3usize, Some(RecordLength::Value(1)));
        assert_eq!(RecordLength::Value(4) - 4usize, None);
        assert_eq!(RecordLength::Infinite - 4usize, Some(RecordLength::Infinite));
        assert!(RecordLength::Value(usize::MAX) < RecordLength::Infinite);
        assert!(RecordLength::Infinite > usize::MAX);
        assert!(RecordLength::Infinite != 4usize);
        assert!(RecordLength::Value(4) == 4usize);
    }

    #[::fuchsia::test]
    fn test_range_end() {
        assert_eq!(RecordRange::new(25, 3).end(), RecordLength::Value(28));
        assert_eq!(RecordRange::new(25, 0).end(), RecordLength::Infinite);
    }

    #[::fuchsia::test]
    fn test_range_intersects() {
        let r1 = RecordRange::new(25, 3);
        assert!(r1.intersects(&RecordRange::new(25, 1)));
        assert!(r1.intersects(&RecordRange::new(0, 0)));
        assert!(r1.intersects(&RecordRange::new(0, 60)));
        assert!(r1.intersects(&RecordRange::new(27, 8)));
        assert!(!r1.intersects(&RecordRange::new(28, 1)));
        assert!(!r1.intersects(&RecordRange::new(29, 8)));
        assert!(!r1.intersects(&RecordRange::new(29, 0)));
        assert!(!r1.intersects(&RecordRange::new(0, 8)));
    }

    #[::fuchsia::test]
    fn test_range_intersects_is_symmetric() {
        let ranges = [
            RecordRange::new(0, 0),
            RecordRange::new(0, 8),
            RecordRange::new(10, 15),
            RecordRange::new(25, 3),
            RecordRange::new(27, 0),
            RecordRange::new(28, 1),
        ];
        for a in &ranges {
            for b in &ranges {
                assert_eq!(a.intersects(b), b.intersects(a), "{:?} vs {:?}", a, b);
            }
        }
    }

    #[::fuchsia::test]
    fn test_range_sub() {
        let r1 = RecordRange::new(25, 3);
        assert_eq!(r1 - RecordRange::new(0, 2), vec!(r1));
        assert_eq!(r1 - RecordRange::new(29, 2), vec!(r1));
        assert_eq!(r1 - RecordRange::new(29, 0), vec!(r1));
        assert_eq!(r1 - RecordRange::new(20, 0), vec!());
        assert_eq!(r1 - RecordRange::new(20, 12), vec!());
        assert_eq!(r1 - RecordRange::new(20, 6), vec!(RecordRange::new(26, 2)));
        assert_eq!(r1 - RecordRange::new(26, 3), vec!(RecordRange::new(25, 1)));
        assert_eq!(r1 - RecordRange::new(26, 0), vec!(RecordRange::new(25, 1)));
        assert_eq!(
            r1 - RecordRange::new(26, 1),
            vec!(RecordRange::new(25, 1), RecordRange::new(27, 1))
        );

        let r2 = RecordRange::new(25, 0);
        assert_eq!(r2 - RecordRange::new(0, 2), vec!(r2));
        assert_eq!(r2 - RecordRange::new(20, 0), vec!());
        assert_eq!(r2 - RecordRange::new(20, 6), vec!(RecordRange::new(26, 0)));
        assert_eq!(r2 - RecordRange::new(26, 0), vec!(RecordRange::new(25, 1)));
        assert_eq!(
            r2 - RecordRange::new(26, 1),
            vec!(RecordRange::new(25, 1), RecordRange::new(27, 0))
        );
    }

    #[::fuchsia::test]
    fn test_range_add() {
        let r1 = RecordRange::new(25, 3);
        assert_eq!(r1 + RecordRange::new(0, 2), vec!(RecordRange::new(0, 2), r1));
        assert_eq!(r1 + RecordRange::new(30, 2), vec!(r1, RecordRange::new(30, 2)));
        assert_eq!(r1 + RecordRange::new(30, 0), vec!(r1, RecordRange::new(30, 0)));
        assert_eq!(r1 + RecordRange::new(22, 3), vec!(RecordRange::new(22, 6)));
        assert_eq!(r1 + RecordRange::new(22, 4), vec!(RecordRange::new(22, 6)));
        assert_eq!(r1 + RecordRange::new(22, 8), vec!(RecordRange::new(22, 8)));
        assert_eq!(r1 + RecordRange::new(22, 0), vec!(RecordRange::new(22, 0)));
        assert_eq!(r1 + RecordRange::new(26, 1), vec!(RecordRange::new(25, 3)));
        assert_eq!(r1 + RecordRange::new(26, 2), vec!(RecordRange::new(25, 3)));
        assert_eq!(r1 + RecordRange::new(26, 8), vec!(RecordRange::new(25, 9)));
        assert_eq!(r1 + RecordRange::new(26, 0), vec!(RecordRange::new(25, 0)));
    }
}