starnix_core/vfs/
record_locks.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::task::{CurrentTask, WaitQueue, Waiter};
6use crate::vfs::{FdTableId, FileObject, FileObjectId};
7use starnix_sync::{Locked, Mutex, Unlocked};
8use starnix_uapi::errors::{EAGAIN, Errno};
9use starnix_uapi::{
10    __kernel_off_t, F_GETLK, F_GETLK64, F_OFD_GETLK, F_OFD_SETLK, F_OFD_SETLKW, F_RDLCK, F_SETLK,
11    F_SETLK64, F_SETLKW, F_SETLKW64, F_UNLCK, F_WRLCK, SEEK_CUR, SEEK_END, SEEK_SET, c_short,
12    errno, error, pid_t, uapi,
13};
14use std::collections::BTreeSet;
15
/// Length of a record range, also used to express the end offset of a range.
///
/// A requested length of `0` is mapped to `Infinite` by `RecordLength::new`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RecordLength {
    /// A finite length (or end offset).
    Value(usize),
    /// No upper bound.
    Infinite,
}
21
22impl RecordLength {
23    fn new(value: usize) -> Self {
24        if value == 0 { Self::Infinite } else { Self::Value(value) }
25    }
26    fn value(&self) -> __kernel_off_t {
27        match self {
28            Self::Value(e) => *e as __kernel_off_t,
29            Self::Infinite => 0,
30        }
31    }
32}
33
34impl std::ops::Add<usize> for RecordLength {
35    type Output = Self;
36
37    fn add(self, element: usize) -> Self {
38        match self {
39            Self::Value(e) => Self::Value(e.saturating_add(element)),
40            Self::Infinite => Self::Infinite,
41        }
42    }
43}
44
45impl std::ops::Sub<usize> for RecordLength {
46    type Output = Option<Self>;
47
48    fn sub(self, element: usize) -> Option<Self> {
49        match self {
50            Self::Value(e) if e > element => Some(Self::Value(e - element)),
51            Self::Infinite => Some(Self::Infinite),
52            _ => None,
53        }
54    }
55}
56
57impl std::cmp::PartialEq<usize> for RecordLength {
58    fn eq(&self, other: &usize) -> bool {
59        match self {
60            Self::Value(e) => e == other,
61            Self::Infinite => false,
62        }
63    }
64}
65
66impl std::cmp::PartialOrd<usize> for RecordLength {
67    fn partial_cmp(&self, other: &usize) -> Option<std::cmp::Ordering> {
68        match self {
69            Self::Value(e) => e.partial_cmp(other),
70            Self::Infinite => Some(std::cmp::Ordering::Greater),
71        }
72    }
73}
74
75impl Ord for RecordLength {
76    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
77        if self == other {
78            std::cmp::Ordering::Equal
79        } else {
80            match self {
81                Self::Value(e1) => match other {
82                    Self::Value(e2) => e1.cmp(e2),
83                    Self::Infinite => std::cmp::Ordering::Less,
84                },
85                Self::Infinite => std::cmp::Ordering::Greater,
86            }
87        }
88    }
89}
90
91impl PartialOrd for RecordLength {
92    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
93        Some(self.cmp(other))
94    }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
98struct RecordRange {
99    start: usize,
100    length: RecordLength,
101}
102
103impl RecordRange {
104    fn new(start: usize, length: usize) -> Self {
105        Self { start, length: RecordLength::new(length) }
106    }
107
108    /// Build a new `RecordRange` from the whence, start and length information in the flock
109    /// struct. The opened file is used when the position needs to be considered from the local
110    /// position of the file or the end of the file.
111    fn build(flock: &uapi::flock, file: &FileObject) -> Result<RecordRange, Errno> {
112        let origin: __kernel_off_t = match flock.l_whence as u32 {
113            SEEK_SET => 0,
114            SEEK_CUR => *file.offset.lock(),
115            SEEK_END => file.node().info().size.try_into().map_err(|_| errno!(EINVAL))?,
116            _ => {
117                return error!(EINVAL);
118            }
119        };
120        let mut start = origin.checked_add(flock.l_start).ok_or_else(|| errno!(EOVERFLOW))?;
121        let mut length = flock.l_len;
122        if length < 0 {
123            start = start.checked_add(length).ok_or_else(|| errno!(EINVAL))?;
124            length = length.checked_neg().ok_or_else(|| errno!(EINVAL))?;
125        }
126        if start < 0 {
127            return error!(EINVAL);
128        }
129        Ok(Self::new(start as usize, length as usize))
130    }
131
132    fn end(&self) -> RecordLength {
133        self.length + self.start
134    }
135
136    fn intersects(&self, other: &RecordRange) -> bool {
137        let r1 = std::cmp::min(self, other);
138        let r2 = std::cmp::max(self, other);
139        r1.end() > r2.start
140    }
141}
142
143impl std::ops::Sub<RecordRange> for RecordRange {
144    type Output = Vec<Self>;
145
146    fn sub(self, other: RecordRange) -> Vec<RecordRange> {
147        if !self.intersects(&other) {
148            return vec![self];
149        }
150        let mut vec = Vec::with_capacity(2);
151        if self.start < other.start {
152            let length = std::cmp::min(RecordLength::Value(other.start - self.start), self.length);
153            vec.push(RecordRange { start: self.start, length });
154        }
155        if let RecordLength::Value(start) = other.end() {
156            let end = self.end();
157            if let Some(length) = end - start {
158                vec.push(RecordRange { start, length });
159            }
160        }
161        vec
162    }
163}
164
165impl std::ops::Add<RecordRange> for RecordRange {
166    type Output = Vec<Self>;
167
168    fn add(self, other: RecordRange) -> Vec<RecordRange> {
169        let r1 = std::cmp::min(self, other);
170        let r2 = std::cmp::max(self, other);
171        let r1_end = r1.end();
172        if r1_end < r2.start {
173            vec![r1, r2]
174        } else {
175            let end = std::cmp::max(r1_end, r2.end());
176            vec![RecordRange {
177                start: r1.start,
178                length: (end - r1.start).expect("Length is guaranteed to exist"),
179            }]
180        }
181    }
182}
183
/// The kind of a record lock.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum RecordLockType {
    /// Read lock; compatible with other read locks.
    Read,
    /// Write lock; not compatible with any other lock.
    Write,
}
189
190impl RecordLockType {
191    fn build(flock: &uapi::flock) -> Result<Option<Self>, Errno> {
192        match flock.l_type as u32 {
193            F_UNLCK => Ok(None),
194            F_RDLCK => Ok(Some(Self::Read)),
195            F_WRLCK => Ok(Some(Self::Write)),
196            _ => error!(EINVAL),
197        }
198    }
199
200    /// Returns whether the current lock type is compatible with the other lock type. This only
201    /// happends when both locks are read locks.
202    fn is_compatible(&self, other: RecordLockType) -> bool {
203        *self == Self::Read && other == Self::Read
204    }
205
206    fn has_permission(&self, file: &FileObject) -> bool {
207        match self {
208            Self::Read => file.can_read(),
209            Self::Write => file.can_write(),
210        }
211    }
212
213    fn value(&self) -> c_short {
214        match self {
215            Self::Read => F_RDLCK as c_short,
216            Self::Write => F_WRLCK as c_short,
217        }
218    }
219}
220
221#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
222pub enum RecordLockOwner {
223    FdTable(FdTableId),
224    FileObject(FileObjectId),
225}
226
227impl RecordLockOwner {
228    fn new(current_task: &CurrentTask, cmd: RecordLockCommand, file: &FileObject) -> Self {
229        if cmd.is_ofd() {
230            RecordLockOwner::FileObject(file.id)
231        } else {
232            RecordLockOwner::FdTable(current_task.files.id())
233        }
234    }
235}
236
237#[derive(Debug, Clone)]
238struct RecordLock {
239    pub owner: RecordLockOwner,
240    pub range: RecordRange,
241    pub lock_type: RecordLockType,
242    pub process_id: pid_t,
243}
244
245impl RecordLock {
246    fn as_tuple(&self) -> (RecordLockOwner, &RecordRange, RecordLockType) {
247        (self.owner, &self.range, self.lock_type)
248    }
249}
250
251impl PartialEq for RecordLock {
252    fn eq(&self, other: &Self) -> bool {
253        self.as_tuple() == other.as_tuple()
254    }
255}
256
257impl Eq for RecordLock {}
258
259impl Ord for RecordLock {
260    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
261        self.as_tuple().cmp(&other.as_tuple())
262    }
263}
264
265impl PartialOrd for RecordLock {
266    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
267        Some(self.cmp(other))
268    }
269}
270#[derive(Default, Debug)]
271pub struct RecordLocksState {
272    locks: BTreeSet<RecordLock>,
273    queue: WaitQueue,
274}
275
276impl RecordLocksState {
277    /// Returns any lock that would conflict with a lock of type `lock_type` over `range` by
278    /// `fd_table`.
279    fn get_conflicting_lock(
280        &self,
281        owner: RecordLockOwner,
282        lock_type: RecordLockType,
283        range: &RecordRange,
284    ) -> Option<uapi::flock> {
285        for record in &self.locks {
286            if owner == record.owner {
287                continue;
288            }
289            if lock_type.is_compatible(record.lock_type) {
290                continue;
291            }
292            if range.intersects(&record.range) {
293                return Some(uapi::flock {
294                    l_type: record.lock_type.value(),
295                    l_whence: SEEK_SET as c_short,
296                    l_start: record.range.start as __kernel_off_t,
297                    l_len: record.range.length.value(),
298                    l_pid: record.process_id,
299                    ..Default::default()
300                });
301            }
302        }
303        None
304    }
305
306    fn apply_lock(
307        &mut self,
308        process_id: pid_t,
309        owner: RecordLockOwner,
310        lock_type: RecordLockType,
311        range: RecordRange,
312    ) -> Result<(), Errno> {
313        let mut owned_locks_in_range = Vec::new();
314        for lock in self.locks.iter().filter(|record| range.intersects(&record.range)) {
315            if owner == lock.owner {
316                owned_locks_in_range.push(lock.clone());
317            } else if !lock_type.is_compatible(lock.lock_type) {
318                // conflict
319                return error!(EAGAIN);
320            }
321        }
322        let mut new_lock = RecordLock { owner, range, lock_type, process_id };
323        for lock in owned_locks_in_range {
324            self.locks.remove(&lock);
325            if lock.lock_type == lock_type {
326                let new_ranges = new_lock.range + lock.range;
327                assert!(new_ranges.len() == 1);
328                new_lock.range = new_ranges[0];
329            } else {
330                for range in lock.range - new_lock.range {
331                    let mut remaining_lock = lock.clone();
332                    remaining_lock.range = range;
333                    self.locks.insert(remaining_lock);
334                }
335            }
336        }
337        self.locks.insert(new_lock);
338        self.queue.notify_all();
339        Ok(())
340    }
341
342    fn unlock(&mut self, owner: RecordLockOwner, range: RecordRange) -> Result<(), Errno> {
343        let intersection_locks: Vec<_> = self
344            .locks
345            .iter()
346            .filter(|record| owner == record.owner && range.intersects(&record.range))
347            .cloned()
348            .collect();
349        for lock in intersection_locks {
350            self.locks.remove(&lock);
351            for new_range in lock.range - range {
352                let mut new_lock = lock.clone();
353                new_lock.range = new_range;
354                self.locks.insert(new_lock);
355            }
356        }
357        self.queue.notify_all();
358        Ok(())
359    }
360
361    fn release_locks(&mut self, owner: RecordLockOwner) {
362        self.locks.retain(|lock| lock.owner != owner);
363        self.queue.notify_all();
364    }
365}
366
/// The record lock operations, as decoded from a raw command by `RecordLockCommand::from_raw`.
// The variants are spelled like the `F_*` constants they are decoded from, hence the lints
// disabled below.
#[allow(clippy::upper_case_acronyms)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecordLockCommand {
    /// `F_SETLK` / `F_SETLK64`.
    SETLK,
    /// `F_SETLKW` / `F_SETLKW64`.
    SETLKW,
    /// `F_GETLK` / `F_GETLK64`.
    GETLK,
    /// `F_OFD_GETLK`.
    OFD_GETLK,
    /// `F_OFD_SETLK`.
    OFD_SETLK,
    /// `F_OFD_SETLKW`.
    OFD_SETLKW,
}
378
379impl RecordLockCommand {
380    pub fn from_raw(cmd: u32) -> Option<Self> {
381        match cmd {
382            F_SETLK | F_SETLK64 => Some(RecordLockCommand::SETLK),
383            F_SETLKW | F_SETLKW64 => Some(RecordLockCommand::SETLKW),
384            F_GETLK | F_GETLK64 => Some(RecordLockCommand::GETLK),
385            F_OFD_GETLK => Some(RecordLockCommand::OFD_GETLK),
386            F_OFD_SETLK => Some(RecordLockCommand::OFD_SETLK),
387            F_OFD_SETLKW => Some(RecordLockCommand::OFD_SETLKW),
388            _ => None,
389        }
390    }
391
392    fn is_ofd(&self) -> bool {
393        match self {
394            RecordLockCommand::SETLK | RecordLockCommand::SETLKW | RecordLockCommand::GETLK => {
395                false
396            }
397            RecordLockCommand::OFD_GETLK
398            | RecordLockCommand::OFD_SETLK
399            | RecordLockCommand::OFD_SETLKW => true,
400        }
401    }
402
403    fn is_get(&self) -> bool {
404        *self == RecordLockCommand::GETLK || *self == RecordLockCommand::OFD_GETLK
405    }
406
407    fn is_blocking(&self) -> bool {
408        *self == RecordLockCommand::SETLKW || *self == RecordLockCommand::OFD_SETLKW
409    }
410}
411
412#[derive(Default, Debug)]
413pub struct RecordLocks {
414    state: Mutex<RecordLocksState>,
415}
416
417impl RecordLocks {
418    /// Apply the fcntl lock operation by the given `current_task`, on the given `file`.
419    ///
420    /// If this method succeed, and doesn't return None, the returned flock struct must be used to
421    /// overwrite the content of the flock struct passed by the user.
422    pub fn lock(
423        &self,
424        locked: &mut Locked<Unlocked>,
425        current_task: &CurrentTask,
426        file: &FileObject,
427        cmd: RecordLockCommand,
428        mut flock: uapi::flock,
429    ) -> Result<Option<uapi::flock>, Errno> {
430        if cmd.is_ofd() && flock.l_pid != 0 {
431            return error!(EINVAL);
432        }
433        let owner: RecordLockOwner = RecordLockOwner::new(current_task, cmd, file);
434        let lock_type = RecordLockType::build(&flock)?;
435        let range = RecordRange::build(&flock, file)?;
436        if cmd.is_get() {
437            let lock_type = lock_type.ok_or_else(|| errno!(EINVAL))?;
438            Ok(self.state.lock().get_conflicting_lock(owner, lock_type, &range).or_else(|| {
439                flock.l_type = F_UNLCK as c_short;
440                Some(flock)
441            }))
442        } else {
443            match lock_type {
444                Some(lock_type) => {
445                    if !lock_type.has_permission(file) {
446                        return error!(EBADF);
447                    }
448                    let blocking = cmd.is_blocking();
449                    loop {
450                        let mut state = self.state.lock();
451                        let waiter = blocking.then(|| {
452                            let waiter = Waiter::new();
453                            state.queue.wait_async(&waiter);
454                            waiter
455                        });
456                        let process_id =
457                            if cmd.is_ofd() { -1 } else { current_task.thread_group().leader };
458                        match state.apply_lock(process_id, owner, lock_type, range) {
459                            Err(errno) if blocking && errno == EAGAIN => {
460                                // TODO(qsr): Check deadlocks.
461                                if let Some(waiter) = waiter {
462                                    std::mem::drop(state);
463                                    waiter.wait(locked, current_task)?;
464                                }
465                            }
466                            result => return result.map(|_| None),
467                        }
468                    }
469                }
470                None => {
471                    self.state.lock().unlock(owner, range)?;
472                }
473            }
474            Ok(None)
475        }
476    }
477
478    pub fn release_locks(&self, owner: RecordLockOwner) {
479        self.state.lock().release_locks(owner);
480    }
481}
482
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand for `RecordRange::new`; a `length` of 0 builds an unbounded range.
    fn range(start: usize, length: usize) -> RecordRange {
        RecordRange::new(start, length)
    }

    #[::fuchsia::test]
    fn test_range_intersects() {
        let r1 = range(25, 3);

        for overlapping in [range(25, 1), range(0, 0), range(0, 60), range(27, 8)] {
            assert!(r1.intersects(&overlapping), "{:?} must intersect {:?}", r1, overlapping);
        }
        for disjoint in [range(28, 1), range(29, 8), range(29, 0), range(0, 8)] {
            assert!(!r1.intersects(&disjoint), "{:?} must not intersect {:?}", r1, disjoint);
        }
    }

    #[::fuchsia::test]
    fn test_range_sub() {
        // Finite range.
        let r1 = range(25, 3);
        let finite_cases: Vec<(RecordRange, Vec<RecordRange>)> = vec![
            (range(0, 2), vec![r1]),
            (range(29, 2), vec![r1]),
            (range(29, 0), vec![r1]),
            (range(20, 0), vec![]),
            (range(20, 12), vec![]),
            (range(20, 6), vec![range(26, 2)]),
            (range(26, 3), vec![range(25, 1)]),
            (range(26, 0), vec![range(25, 1)]),
            (range(26, 1), vec![range(25, 1), range(27, 1)]),
        ];
        for (removed, expected) in finite_cases {
            assert_eq!(r1 - removed, expected, "{:?} - {:?}", r1, removed);
        }

        // Unbounded range.
        let r2 = range(25, 0);
        let unbounded_cases: Vec<(RecordRange, Vec<RecordRange>)> = vec![
            (range(0, 2), vec![r2]),
            (range(20, 0), vec![]),
            (range(20, 6), vec![range(26, 0)]),
            (range(26, 0), vec![range(25, 1)]),
            (range(26, 1), vec![range(25, 1), range(27, 0)]),
        ];
        for (removed, expected) in unbounded_cases {
            assert_eq!(r2 - removed, expected, "{:?} - {:?}", r2, removed);
        }
    }

    #[::fuchsia::test]
    fn test_range_add() {
        let r1 = range(25, 3);
        let cases: Vec<(RecordRange, Vec<RecordRange>)> = vec![
            // Disjoint ranges are returned in order.
            (range(0, 2), vec![range(0, 2), r1]),
            (range(30, 2), vec![r1, range(30, 2)]),
            (range(30, 0), vec![r1, range(30, 0)]),
            // Ranges starting before `r1` that touch or overlap it.
            (range(22, 3), vec![range(22, 6)]),
            (range(22, 4), vec![range(22, 6)]),
            (range(22, 8), vec![range(22, 8)]),
            (range(22, 0), vec![range(22, 0)]),
            // Ranges starting inside `r1`.
            (range(26, 1), vec![range(25, 3)]),
            (range(26, 2), vec![range(25, 3)]),
            (range(26, 8), vec![range(25, 9)]),
            (range(26, 0), vec![range(25, 0)]),
        ];
        for (added, expected) in cases {
            assert_eq!(r1 + added, expected, "{:?} + {:?}", r1, added);
        }
    }
}