starnix_core/mm/
userfault.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::mm::{MemoryManager, PAGE_SIZE};
6use bitflags::bitflags;
7use range_map::RangeMap;
8use starnix_logging::track_stub;
9use starnix_sync::{LockBefore, Locked, OrderedMutex, UserFaultInner};
10use starnix_uapi::errors::Errno;
11use starnix_uapi::user_address::UserAddress;
12use starnix_uapi::{
13    UFFD_FEATURE_EVENT_FORK, UFFD_FEATURE_EVENT_REMAP, UFFD_FEATURE_EVENT_REMOVE,
14    UFFD_FEATURE_EVENT_UNMAP, UFFD_FEATURE_MINOR_HUGETLBFS, UFFD_FEATURE_MINOR_SHMEM,
15    UFFD_FEATURE_MISSING_HUGETLBFS, UFFD_FEATURE_MISSING_SHMEM, UFFD_FEATURE_SIGBUS,
16    UFFD_FEATURE_THREAD_ID, UFFDIO_CONTINUE_MODE_DONTWAKE, UFFDIO_COPY_MODE_DONTWAKE,
17    UFFDIO_COPY_MODE_WP, UFFDIO_REGISTER_MODE_MINOR, UFFDIO_REGISTER_MODE_MISSING,
18    UFFDIO_REGISTER_MODE_WP, UFFDIO_ZEROPAGE_MODE_DONTWAKE, errno, error,
19};
20use std::ops::Range;
21use std::sync::{Arc, Weak};
22
#[derive(Debug)]
pub struct UserFault {
    /// The memory manager this userfault object operates on. Held as a `Weak` reference so this
    /// object does not keep the memory manager alive; operations upgrade it and fail with EINVAL
    /// if it has already been dropped.
    mm: Weak<MemoryManager>,
    /// Mutable state (features and registered pages), guarded at the `UserFaultInner` lock level.
    state: OrderedMutex<UserFaultState, UserFaultInner>,
}
28
#[derive(Debug, Clone)]
struct UserFaultState {
    /// Features this userfault was initialized with. `None` until `UserFault::initialize` has
    /// been called; most operations require initialization and fail with EINVAL before it.
    features: Option<UserFaultFeatures>,

    /// Pages that are currently registered with this userfault object, and whether they are
    /// already populated.
    userfault_pages: RangeMap<UserAddress, bool>,
}
38
39impl UserFault {
40    pub fn new(mm: Weak<MemoryManager>) -> Self {
41        Self { mm, state: OrderedMutex::new(UserFaultState::new()) }
42    }
43
44    pub fn insert_pages<L>(&self, locked: &mut Locked<L>, range: Range<UserAddress>, value: bool)
45    where
46        L: LockBefore<UserFaultInner>,
47    {
48        // RangeMap uses #[must_use] for its default usecase but this drop is trivial.
49        let _ = self.state.lock(locked).userfault_pages.insert(range, value);
50    }
51
52    pub fn remove_pages<L>(&self, locked: &mut Locked<L>, range: Range<UserAddress>) -> bool
53    where
54        L: LockBefore<UserFaultInner>,
55    {
56        !self.state.lock(locked).userfault_pages.remove(range).is_empty()
57    }
58
59    pub fn get_registered_pages_overlapping_range<L>(
60        &self,
61        locked: &mut Locked<L>,
62        range: Range<UserAddress>,
63    ) -> Vec<Range<UserAddress>>
64    where
65        L: LockBefore<UserFaultInner>,
66    {
67        self.state.lock(locked).userfault_pages.get_keys(range).cloned().collect()
68    }
69
70    pub fn contains_addr<L>(&self, locked: &mut Locked<L>, addr: UserAddress) -> bool
71    where
72        L: LockBefore<UserFaultInner>,
73    {
74        self.state.lock(locked).userfault_pages.get(addr).is_some()
75    }
76
77    pub fn get_first_populated_page_after<L>(
78        &self,
79        locked: &mut Locked<L>,
80        addr: UserAddress,
81    ) -> Option<UserAddress>
82    where
83        L: LockBefore<UserFaultInner>,
84    {
85        self.state.lock(locked).userfault_pages.get(addr).map(|(affected_range, is_populated)| {
86            if *is_populated { addr } else { affected_range.end }
87        })
88    }
89
90    pub fn is_initialized<L>(self: &Arc<Self>, locked: &mut Locked<L>) -> bool
91    where
92        L: LockBefore<UserFaultInner>,
93    {
94        self.state.lock(locked).features.is_some()
95    }
96
97    pub fn has_features<L>(
98        self: &Arc<Self>,
99        locked: &mut Locked<L>,
100        features: UserFaultFeatures,
101    ) -> bool
102    where
103        L: LockBefore<UserFaultInner>,
104    {
105        self.state.lock(locked).features.map(|f| f.contains(features)).unwrap_or(false)
106    }
107
108    pub fn initialize<L>(self: &Arc<Self>, locked: &mut Locked<L>, features: UserFaultFeatures)
109    where
110        L: LockBefore<UserFaultInner>,
111    {
112        self.state.lock(locked).features = Some(features);
113    }
114
115    pub fn op_register<L>(
116        self: &Arc<Self>,
117        locked: &mut Locked<L>,
118        start: UserAddress,
119        len: u64,
120        mode: FaultRegisterMode,
121    ) -> Result<SupportedUserFaultIoctls, Errno>
122    where
123        L: LockBefore<UserFaultInner>,
124    {
125        if !self.is_initialized(locked) {
126            return error!(EINVAL);
127        }
128        if !self.has_features(locked, UserFaultFeatures::SIGBUS) {
129            track_stub!(TODO("https://fxbug.dev/391599171"), "userfault without SIGBUS feature");
130            return error!(ENOTSUP);
131        }
132        check_op_range(start, len)?;
133        let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
134
135        mm.register_with_uffd(locked, start, len as usize, self, mode)?;
136        Ok(SupportedUserFaultIoctls::COPY | SupportedUserFaultIoctls::ZERO_PAGE)
137    }
138
139    pub fn op_unregister<L>(
140        self: &Arc<Self>,
141        locked: &mut Locked<L>,
142        start: UserAddress,
143        len: u64,
144    ) -> Result<(), Errno>
145    where
146        L: LockBefore<UserFaultInner>,
147    {
148        if !self.is_initialized(locked) {
149            return error!(EINVAL);
150        }
151        check_op_range(start, len)?;
152        let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
153        mm.unregister_range_from_uffd(locked, self, start, len as usize)
154    }
155
156    pub fn op_copy<L>(
157        self: &Arc<Self>,
158        locked: &mut Locked<L>,
159        mm_source: &MemoryManager,
160        source: UserAddress,
161        dest: UserAddress,
162        len: u64,
163        _mode: FaultCopyMode,
164    ) -> Result<usize, Errno>
165    where
166        L: LockBefore<UserFaultInner>,
167    {
168        if !self.is_initialized(locked) {
169            return error!(EINVAL);
170        }
171        check_op_range(source, len)?;
172        check_op_range(dest, len)?;
173        let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
174
175        // If the copy happens inside the same process, do it inside this process' memory manager
176        // so that the lock is held throughout the operation.
177        if Arc::as_ptr(&mm) == mm_source as *const MemoryManager {
178            mm.copy_from_uffd(locked, source, dest, len as usize, self)
179        } else {
180            let mut buf = vec![std::mem::MaybeUninit::uninit(); len as usize];
181            let buf = mm_source.syscall_read_memory(source, &mut buf)?;
182            mm.fill_from_uffd(locked, dest, buf, len as usize, self)
183        }
184    }
185
186    pub fn op_zero<L>(
187        self: &Arc<Self>,
188        locked: &mut Locked<L>,
189        start: UserAddress,
190        len: u64,
191        _mode: FaultZeroMode,
192    ) -> Result<usize, Errno>
193    where
194        L: LockBefore<UserFaultInner>,
195    {
196        if !self.is_initialized(locked) {
197            return error!(EINVAL);
198        }
199        check_op_range(start, len)?;
200        let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
201        mm.zero_from_uffd(locked, start, len as usize, self)
202    }
203
204    pub fn cleanup<L>(self: &Arc<Self>, locked: &mut Locked<L>)
205    where
206        L: LockBefore<UserFaultInner>,
207    {
208        if let Some(mm) = self.mm.upgrade() {
209            mm.unregister_uffd(locked, self);
210        }
211    }
212}
213
214impl UserFaultState {
215    pub fn new() -> Self {
216        Self { features: None, userfault_pages: RangeMap::default() }
217    }
218}
219
220bitflags! {
221    #[derive(Debug, Clone, Copy, Eq, PartialEq)]
222    pub struct UserFaultFeatures: u32 {
223        const ALL_SUPPORTED = UFFD_FEATURE_SIGBUS;
224        const EVENT_FORK = UFFD_FEATURE_EVENT_FORK;
225        const EVENT_REMAP = UFFD_FEATURE_EVENT_REMAP;
226        const EVENT_REMOVE = UFFD_FEATURE_EVENT_REMOVE;
227        const EVENT_UNMAP = UFFD_FEATURE_EVENT_UNMAP;
228        const MISSING_HUGETLBFS = UFFD_FEATURE_MISSING_HUGETLBFS;
229        const MISSING_SHMEM = UFFD_FEATURE_MISSING_SHMEM;
230        const SIGBUS = UFFD_FEATURE_SIGBUS;
231        const THREAD_ID = UFFD_FEATURE_THREAD_ID;
232        const MINOR_HUGETLBFS = UFFD_FEATURE_MINOR_HUGETLBFS;
233        const MINOR_SHMEM = UFFD_FEATURE_MINOR_SHMEM;
234    }
235
236    #[derive(Debug, Clone, Copy, Eq, PartialEq)]
237    pub struct FaultRegisterMode: u32 {
238        const MINOR = UFFDIO_REGISTER_MODE_MINOR;
239        const MISSING = UFFDIO_REGISTER_MODE_MISSING;
240        const WRITE_PROTECT = UFFDIO_REGISTER_MODE_WP;
241    }
242
243    pub struct FaultCopyMode: u32 {
244        const DONT_WAKE = UFFDIO_COPY_MODE_DONTWAKE;
245        const WRITE_PROTECT = UFFDIO_COPY_MODE_WP;
246    }
247
248    pub struct FaultZeroMode: u32 {
249        const DONT_WAKE = UFFDIO_ZEROPAGE_MODE_DONTWAKE;
250    }
251
252    pub struct FaultContinueMode: u32 {
253        const DONT_WAKE = UFFDIO_CONTINUE_MODE_DONTWAKE;
254    }
255
256
257    pub struct SupportedUserFaultIoctls: u64 {
258        const COPY = 1 << starnix_uapi::_UFFDIO_COPY;
259        const WAKE = 1 << starnix_uapi::_UFFDIO_WAKE;
260        const WRITE_PROTECT = 1 << starnix_uapi::_UFFDIO_WRITEPROTECT;
261        const ZERO_PAGE = 1 << starnix_uapi::_UFFDIO_ZEROPAGE;
262        const CONTINUE = 1 << starnix_uapi::_UFFDIO_CONTINUE;
263    }
264}
265
266fn check_op_range(addr: UserAddress, len: u64) -> Result<(), Errno> {
267    if addr.is_aligned(*PAGE_SIZE) && len % *PAGE_SIZE == 0 && len > 0 {
268        Ok(())
269    } else {
270        error!(EINVAL)
271    }
272}