1use crate::mm::{MemoryManager, PAGE_SIZE};
6use bitflags::bitflags;
7use range_map::RangeMap;
8use starnix_logging::track_stub;
9use starnix_sync::{LockBefore, Locked, OrderedMutex, UserFaultInner};
10use starnix_uapi::errors::Errno;
11use starnix_uapi::user_address::UserAddress;
12use starnix_uapi::{
13 UFFD_FEATURE_EVENT_FORK, UFFD_FEATURE_EVENT_REMAP, UFFD_FEATURE_EVENT_REMOVE,
14 UFFD_FEATURE_EVENT_UNMAP, UFFD_FEATURE_MINOR_HUGETLBFS, UFFD_FEATURE_MINOR_SHMEM,
15 UFFD_FEATURE_MISSING_HUGETLBFS, UFFD_FEATURE_MISSING_SHMEM, UFFD_FEATURE_SIGBUS,
16 UFFD_FEATURE_THREAD_ID, UFFDIO_CONTINUE_MODE_DONTWAKE, UFFDIO_COPY_MODE_DONTWAKE,
17 UFFDIO_COPY_MODE_WP, UFFDIO_REGISTER_MODE_MINOR, UFFDIO_REGISTER_MODE_MISSING,
18 UFFDIO_REGISTER_MODE_WP, UFFDIO_ZEROPAGE_MODE_DONTWAKE, errno, error,
19};
20use std::ops::Range;
21use std::sync::{Arc, Weak};
22
/// A userfault object bound to a single [`MemoryManager`].
///
/// The memory manager is held weakly: the `op_*` methods and `cleanup` upgrade the
/// handle on every call and the `op_*` methods fail with `EINVAL` once it is gone.
#[derive(Debug)]
pub struct UserFault {
    /// Memory manager that the `op_*` methods and `cleanup` delegate to.
    mm: Weak<MemoryManager>,
    /// Feature set (once `initialize` has run) and the tracked page ranges, guarded by a
    /// mutex at the `UserFaultInner` lock level.
    state: OrderedMutex<UserFaultState, UserFaultInner>,
}
28
/// Mutable state of a [`UserFault`], stored behind its `state` mutex.
#[derive(Debug, Clone)]
struct UserFaultState {
    /// `None` until `UserFault::initialize` stores a feature set.
    features: Option<UserFaultFeatures>,

    /// Address ranges tracked by this object. The `bool` attached to each range is read as
    /// "is populated" by `UserFault::get_first_populated_page_after`.
    userfault_pages: RangeMap<UserAddress, bool>,
}
38
39impl UserFault {
40 pub fn new(mm: Weak<MemoryManager>) -> Self {
41 Self { mm, state: OrderedMutex::new(UserFaultState::new()) }
42 }
43
44 pub fn insert_pages<L>(&self, locked: &mut Locked<L>, range: Range<UserAddress>, value: bool)
45 where
46 L: LockBefore<UserFaultInner>,
47 {
48 let _ = self.state.lock(locked).userfault_pages.insert(range, value);
50 }
51
52 pub fn remove_pages<L>(&self, locked: &mut Locked<L>, range: Range<UserAddress>) -> bool
53 where
54 L: LockBefore<UserFaultInner>,
55 {
56 !self.state.lock(locked).userfault_pages.remove(range).is_empty()
57 }
58
59 pub fn get_registered_pages_overlapping_range<L>(
60 &self,
61 locked: &mut Locked<L>,
62 range: Range<UserAddress>,
63 ) -> Vec<Range<UserAddress>>
64 where
65 L: LockBefore<UserFaultInner>,
66 {
67 self.state.lock(locked).userfault_pages.get_keys(range).cloned().collect()
68 }
69
70 pub fn contains_addr<L>(&self, locked: &mut Locked<L>, addr: UserAddress) -> bool
71 where
72 L: LockBefore<UserFaultInner>,
73 {
74 self.state.lock(locked).userfault_pages.get(addr).is_some()
75 }
76
77 pub fn get_first_populated_page_after<L>(
78 &self,
79 locked: &mut Locked<L>,
80 addr: UserAddress,
81 ) -> Option<UserAddress>
82 where
83 L: LockBefore<UserFaultInner>,
84 {
85 self.state.lock(locked).userfault_pages.get(addr).map(|(affected_range, is_populated)| {
86 if *is_populated { addr } else { affected_range.end }
87 })
88 }
89
90 pub fn is_initialized<L>(self: &Arc<Self>, locked: &mut Locked<L>) -> bool
91 where
92 L: LockBefore<UserFaultInner>,
93 {
94 self.state.lock(locked).features.is_some()
95 }
96
97 pub fn has_features<L>(
98 self: &Arc<Self>,
99 locked: &mut Locked<L>,
100 features: UserFaultFeatures,
101 ) -> bool
102 where
103 L: LockBefore<UserFaultInner>,
104 {
105 self.state.lock(locked).features.map(|f| f.contains(features)).unwrap_or(false)
106 }
107
108 pub fn initialize<L>(self: &Arc<Self>, locked: &mut Locked<L>, features: UserFaultFeatures)
109 where
110 L: LockBefore<UserFaultInner>,
111 {
112 self.state.lock(locked).features = Some(features);
113 }
114
115 pub fn op_register<L>(
116 self: &Arc<Self>,
117 locked: &mut Locked<L>,
118 start: UserAddress,
119 len: u64,
120 mode: FaultRegisterMode,
121 ) -> Result<SupportedUserFaultIoctls, Errno>
122 where
123 L: LockBefore<UserFaultInner>,
124 {
125 if !self.is_initialized(locked) {
126 return error!(EINVAL);
127 }
128 if !self.has_features(locked, UserFaultFeatures::SIGBUS) {
129 track_stub!(TODO("https://fxbug.dev/391599171"), "userfault without SIGBUS feature");
130 return error!(ENOTSUP);
131 }
132 check_op_range(start, len)?;
133 let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
134
135 mm.register_with_uffd(locked, start, len as usize, self, mode)?;
136 Ok(SupportedUserFaultIoctls::COPY | SupportedUserFaultIoctls::ZERO_PAGE)
137 }
138
139 pub fn op_unregister<L>(
140 self: &Arc<Self>,
141 locked: &mut Locked<L>,
142 start: UserAddress,
143 len: u64,
144 ) -> Result<(), Errno>
145 where
146 L: LockBefore<UserFaultInner>,
147 {
148 if !self.is_initialized(locked) {
149 return error!(EINVAL);
150 }
151 check_op_range(start, len)?;
152 let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
153 mm.unregister_range_from_uffd(locked, self, start, len as usize)
154 }
155
156 pub fn op_copy<L>(
157 self: &Arc<Self>,
158 locked: &mut Locked<L>,
159 mm_source: &MemoryManager,
160 source: UserAddress,
161 dest: UserAddress,
162 len: u64,
163 _mode: FaultCopyMode,
164 ) -> Result<usize, Errno>
165 where
166 L: LockBefore<UserFaultInner>,
167 {
168 if !self.is_initialized(locked) {
169 return error!(EINVAL);
170 }
171 check_op_range(source, len)?;
172 check_op_range(dest, len)?;
173 let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
174
175 if Arc::as_ptr(&mm) == mm_source as *const MemoryManager {
178 mm.copy_from_uffd(locked, source, dest, len as usize, self)
179 } else {
180 let mut buf = vec![std::mem::MaybeUninit::uninit(); len as usize];
181 let buf = mm_source.syscall_read_memory(source, &mut buf)?;
182 mm.fill_from_uffd(locked, dest, buf, len as usize, self)
183 }
184 }
185
186 pub fn op_zero<L>(
187 self: &Arc<Self>,
188 locked: &mut Locked<L>,
189 start: UserAddress,
190 len: u64,
191 _mode: FaultZeroMode,
192 ) -> Result<usize, Errno>
193 where
194 L: LockBefore<UserFaultInner>,
195 {
196 if !self.is_initialized(locked) {
197 return error!(EINVAL);
198 }
199 check_op_range(start, len)?;
200 let mm = self.mm.upgrade().ok_or_else(|| errno!(EINVAL))?;
201 mm.zero_from_uffd(locked, start, len as usize, self)
202 }
203
204 pub fn cleanup<L>(self: &Arc<Self>, locked: &mut Locked<L>)
205 where
206 L: LockBefore<UserFaultInner>,
207 {
208 if let Some(mm) = self.mm.upgrade() {
209 mm.unregister_uffd(locked, self);
210 }
211 }
212}
213
214impl UserFaultState {
215 pub fn new() -> Self {
216 Self { features: None, userfault_pages: RangeMap::default() }
217 }
218}
219
220bitflags! {
221 #[derive(Debug, Clone, Copy, Eq, PartialEq)]
222 pub struct UserFaultFeatures: u32 {
223 const ALL_SUPPORTED = UFFD_FEATURE_SIGBUS;
224 const EVENT_FORK = UFFD_FEATURE_EVENT_FORK;
225 const EVENT_REMAP = UFFD_FEATURE_EVENT_REMAP;
226 const EVENT_REMOVE = UFFD_FEATURE_EVENT_REMOVE;
227 const EVENT_UNMAP = UFFD_FEATURE_EVENT_UNMAP;
228 const MISSING_HUGETLBFS = UFFD_FEATURE_MISSING_HUGETLBFS;
229 const MISSING_SHMEM = UFFD_FEATURE_MISSING_SHMEM;
230 const SIGBUS = UFFD_FEATURE_SIGBUS;
231 const THREAD_ID = UFFD_FEATURE_THREAD_ID;
232 const MINOR_HUGETLBFS = UFFD_FEATURE_MINOR_HUGETLBFS;
233 const MINOR_SHMEM = UFFD_FEATURE_MINOR_SHMEM;
234 }
235
236 #[derive(Debug, Clone, Copy, Eq, PartialEq)]
237 pub struct FaultRegisterMode: u32 {
238 const MINOR = UFFDIO_REGISTER_MODE_MINOR;
239 const MISSING = UFFDIO_REGISTER_MODE_MISSING;
240 const WRITE_PROTECT = UFFDIO_REGISTER_MODE_WP;
241 }
242
243 pub struct FaultCopyMode: u32 {
244 const DONT_WAKE = UFFDIO_COPY_MODE_DONTWAKE;
245 const WRITE_PROTECT = UFFDIO_COPY_MODE_WP;
246 }
247
248 pub struct FaultZeroMode: u32 {
249 const DONT_WAKE = UFFDIO_ZEROPAGE_MODE_DONTWAKE;
250 }
251
252 pub struct FaultContinueMode: u32 {
253 const DONT_WAKE = UFFDIO_CONTINUE_MODE_DONTWAKE;
254 }
255
256
257 pub struct SupportedUserFaultIoctls: u64 {
258 const COPY = 1 << starnix_uapi::_UFFDIO_COPY;
259 const WAKE = 1 << starnix_uapi::_UFFDIO_WAKE;
260 const WRITE_PROTECT = 1 << starnix_uapi::_UFFDIO_WRITEPROTECT;
261 const ZERO_PAGE = 1 << starnix_uapi::_UFFDIO_ZEROPAGE;
262 const CONTINUE = 1 << starnix_uapi::_UFFDIO_CONTINUE;
263 }
264}
265
266fn check_op_range(addr: UserAddress, len: u64) -> Result<(), Errno> {
267 if addr.is_aligned(*PAGE_SIZE) && len % *PAGE_SIZE == 0 && len > 0 {
268 Ok(())
269 } else {
270 error!(EINVAL)
271 }
272}