1use crate::mm::{
6 DesiredAddress, IOVecPtr, MappingName, MappingOptions, MemoryAccessorExt, ProtectionFlags,
7 RemoteMemoryManager, TaskMemoryAccessor,
8};
9use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
10use crate::task::{CurrentTask, SimpleWaiter, WaitQueue};
11use crate::vfs::eventfd::EventFdFileObject;
12use crate::vfs::syscalls::IocbPtr;
13use crate::vfs::{
14 FdNumber, FileHandle, InputBuffer, OutputBuffer, UserBuffersInputBuffer,
15 UserBuffersOutputBuffer, VecInputBuffer, VecOutputBuffer, WeakFileHandle,
16 checked_add_offset_and_length,
17};
18use smallvec::smallvec;
19use starnix_logging::track_stub;
20use starnix_sync::{InterruptibleEvent, Locked, Mutex, Unlocked};
21use starnix_syscalls::SyscallResult;
22use starnix_types::user_buffer::{UserBuffer, UserBuffers};
23use starnix_uapi::errors::{EINTR, ETIMEDOUT, Errno};
24use starnix_uapi::{
25 IOCB_CMD_PREAD, IOCB_CMD_PREADV, IOCB_CMD_PWRITE, IOCB_CMD_PWRITEV, IOCB_FLAG_RESFD,
26 aio_context_t, errno, error, io_event, iocb,
27};
28use std::collections::VecDeque;
29use std::sync::Arc;
30use zerocopy::IntoBytes;
31
/// Size, in bytes, of the anonymous mapping created per AIO context.
/// NOTE(review): much smaller than a real Linux aio ring; presumably the
/// mapping serves only as an address-space token identifying the context
/// (its address becomes the `aio_context_t`) — confirm.
const AIO_RING_SIZE: usize = 32;
34
/// A single asynchronous I/O context, identified to userspace by the address
/// of the anonymous mapping created in [`AioContext::create`].
pub struct AioContext {
    /// Shared state; the worker threads spawned at creation also hold clones
    /// of this `Arc`.
    inner: Arc<AioContextInner>,
}
40
impl AioContext {
    /// Creates a new context, spawns its read and write worker threads, and
    /// maps an anonymous region whose address is returned as the
    /// `aio_context_t` handle userspace uses to refer to this context.
    ///
    /// The mapping's name owns the `AioContext`, so the context lives until
    /// the mapping is removed.
    pub fn create(
        current_task: &CurrentTask,
        max_operations: usize,
    ) -> Result<aio_context_t, Errno> {
        let context = Arc::new(AioContext { inner: AioContextInner::new(max_operations) });
        context.inner.spawn_worker_for(current_task, WorkerType::Read);
        context.inner.spawn_worker_for(current_task, WorkerType::Write);
        let context_addr = current_task.mm()?.map_anonymous(
            DesiredAddress::Any,
            AIO_RING_SIZE,
            ProtectionFlags::READ | ProtectionFlags::WRITE,
            MappingOptions::ANONYMOUS | MappingOptions::DONT_EXPAND,
            MappingName::AioContext(context),
        )?;
        Ok(context_addr.ptr() as aio_context_t)
    }

    /// Collects completed events, blocking until at least `min_results` are
    /// available or `deadline` passes. See [`AioContextInner::get_events`]
    /// for the timeout/interrupt semantics.
    pub fn get_events(
        &self,
        current_task: &CurrentTask,
        min_results: usize,
        max_results: usize,
        deadline: zx::MonotonicInstant,
    ) -> Result<Vec<io_event>, Errno> {
        self.inner.get_events(current_task, min_results, max_results, deadline)
    }

    /// Validates one `iocb` and queues it for asynchronous execution by the
    /// appropriate worker thread.
    pub fn submit(
        self: &Arc<Self>,
        current_task: &CurrentTask,
        control_block: iocb,
        iocb_addr: IocbPtr,
    ) -> Result<(), Errno> {
        self.inner.submit(current_task, control_block, iocb_addr)
    }

    /// Removes a still-pending operation matching `iocb_addr`. Fails with
    /// `EAGAIN` if it is no longer queued (e.g. it already started).
    pub fn cancel(
        self: &Arc<Self>,
        _current_task: &CurrentTask,
        control_block: iocb,
        iocb_addr: IocbPtr,
    ) -> Result<(), Errno> {
        self.inner.cancel(control_block, iocb_addr)
    }
}
87
88impl std::fmt::Debug for AioContext {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("AioContext").finish()
91 }
92}
93
/// Two `AioContext` handles are equal iff they share the same inner state,
/// i.e. equality is identity of the underlying context.
impl std::cmp::PartialEq for AioContext {
    fn eq(&self, other: &AioContext) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}

impl std::cmp::Eq for AioContext {}
101
impl Drop for AioContext {
    /// Stops the operation queue so the worker threads dequeue `Stop` and
    /// exit instead of blocking forever; also drains any pending operations.
    fn drop(&mut self) {
        self.inner.stop();
    }
}
107
/// State shared between the submitting task and the worker threads.
struct AioContextInner {
    /// Submitted operations waiting to be executed, split per worker type.
    operations: OperationQueue,
    /// Completed `io_event`s waiting to be collected via `get_events`.
    results: ResultQueue,
}
112
113impl AioContextInner {
114 fn new(max_operations: usize) -> Arc<Self> {
115 Arc::new(Self {
116 operations: OperationQueue::new(max_operations),
117 results: Default::default(),
118 })
119 }
120
121 fn stop(&self) {
122 self.operations.stop();
123 }
124
125 fn get_events(
126 &self,
127 current_task: &CurrentTask,
128 min_results: usize,
129 max_results: usize,
130 deadline: zx::MonotonicInstant,
131 ) -> Result<Vec<io_event>, Errno> {
132 let mut events = self.results.dequeue(max_results);
133 if events.len() >= min_results {
134 return Ok(events);
135 }
136 let event = InterruptibleEvent::new();
137 loop {
138 let (mut waiter, guard) = SimpleWaiter::new(&event);
139 self.results.waiters.wait_async_simple(&mut waiter);
140 events.extend(self.results.dequeue(max_results - events.len()));
141 if events.len() >= min_results {
142 return Ok(events);
143 }
144 match current_task.block_until(guard, deadline) {
145 Err(err) if err == ETIMEDOUT => {
146 return Ok(events);
147 }
148 Err(err) if err == EINTR => {
149 if events.is_empty() {
150 Err(err)
151 } else {
152 return Ok(events);
153 }
154 }
155 result => result,
156 }?;
157 }
158 }
159
160 fn submit(
161 self: &Arc<Self>,
162 current_task: &CurrentTask,
163 control_block: iocb,
164 iocb_addr: IocbPtr,
165 ) -> Result<(), Errno> {
166 let op = IoOperation::new(current_task, control_block, iocb_addr)?;
167 self.operations.enqueue(op)
168 }
169
170 fn cancel(self: &Arc<Self>, control_block: iocb, iocb_addr: IocbPtr) -> Result<(), Errno> {
171 let op_type: OpType = (control_block.aio_lio_opcode as u32).try_into()?;
172 self.operations.remove(op_type.worker_type(), iocb_addr)
173 }
174
175 fn spawn_worker_for(self: &Arc<Self>, current_task: &CurrentTask, worker_type: WorkerType) {
176 let creds = current_task.current_creds().clone();
177 let inner = self.clone();
178 let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
179 current_task.override_creds(creds, || {
180 inner.perform_next_action(locked, current_task, worker_type)
181 })
182 };
183 let req = SpawnRequestBuilder::new()
184 .with_debug_name("aio-worker")
185 .with_sync_closure(closure)
186 .build();
187 current_task.kernel().kthreads.spawner().spawn_from_request(req);
188 }
189
190 fn perform_next_action(
191 &self,
192 locked: &mut Locked<Unlocked>,
193 current_task: &CurrentTask,
194 worker_type: WorkerType,
195 ) {
196 while let Ok(IoAction::Op(op)) =
197 self.operations.block_until_dequeue(current_task, worker_type)
198 {
199 let Some(result) = op.execute(locked, current_task) else {
200 return;
201 };
202 self.results.enqueue(op.complete(result));
203
204 if let Some(eventfd) = op.eventfd {
205 if let Some(eventfd) = eventfd.upgrade() {
206 let mut input_buffer = VecInputBuffer::new(1u64.as_bytes());
207 let _ = eventfd.write(locked, current_task, &mut input_buffer);
208 }
209 }
210 }
211 }
212}
213
/// Identifies which of the two per-context worker threads services an
/// operation: one thread handles reads, the other writes.
#[derive(Debug, Clone, Copy)]
enum WorkerType {
    Read,
    Write,
}
219
/// Supported `iocb` opcodes — the subset of Linux `IOCB_CMD_*` values this
/// implementation handles (see the `TryFrom<u32>` impl below).
#[derive(Debug, Clone, Copy)]
enum OpType {
    PRead,
    PReadV,
    PWrite,
    PWriteV,
}
231
232impl OpType {
233 fn worker_type(self) -> WorkerType {
234 match self {
235 OpType::PRead | OpType::PReadV => WorkerType::Read,
236 OpType::PWrite | OpType::PWriteV => WorkerType::Write,
237 }
238 }
239}
240
impl TryFrom<u32> for OpType {
    type Error = Errno;

    /// Converts a raw `aio_lio_opcode` into a supported [`OpType`].
    ///
    /// Unsupported opcodes are recorded via `track_stub` and rejected with
    /// `ENOSYS`.
    fn try_from(opcode: u32) -> Result<Self, Self::Error> {
        match opcode {
            IOCB_CMD_PREAD => Ok(Self::PRead),
            IOCB_CMD_PREADV => Ok(Self::PReadV),
            IOCB_CMD_PWRITE => Ok(Self::PWrite),
            IOCB_CMD_PWRITEV => Ok(Self::PWriteV),
            _ => {
                track_stub!(TODO("https://fxbug.dev/297433877"), "io_submit opcode", opcode);
                return error!(ENOSYS);
            }
        }
    }
}
/// A single submitted asynchronous I/O operation, captured at `io_submit`
/// time and executed later on a worker thread.
struct IoOperation {
    op_type: OpType,
    /// Target file, held weakly so a queued operation does not keep the file
    /// alive; it may be gone by the time the worker runs (see `execute`).
    file: WeakFileHandle,
    /// Memory manager of the submitting task, used by the worker thread to
    /// access the user buffers remotely.
    mm: RemoteMemoryManager,
    /// User-space buffers to copy data into (reads) or out of (writes).
    buffers: UserBuffers,
    /// Byte offset into the file, from `iocb.aio_offset`.
    offset: usize,
    /// Caller-chosen identifier (`iocb.aio_data`), echoed back in the
    /// completion event's `data` field.
    id: u64,
    /// User address of the originating `iocb`; used to match cancellation
    /// requests and reported in the completion event's `obj` field.
    iocb_addr: IocbPtr,
    /// Optional eventfd (from `IOCB_FLAG_RESFD`) signalled on completion.
    eventfd: Option<WeakFileHandle>,
}
267
impl IoOperation {
    /// Captures an operation from a `iocb` at submit time.
    ///
    /// Validates the reserved field, opcode, offset, the fd's read/write
    /// permission, the user buffers, and — when `IOCB_FLAG_RESFD` is set —
    /// that `aio_resfd` actually refers to an eventfd.
    fn new(
        current_task: &CurrentTask,
        control_block: iocb,
        iocb_addr: IocbPtr,
    ) -> Result<Self, Errno> {
        // Reserved fields must be zero.
        if control_block.aio_reserved2 != 0 {
            return error!(EINVAL);
        }
        let file =
            current_task.get_file(FdNumber::from_raw(control_block.aio_fildes as i32))?;
        let op_type = (control_block.aio_lio_opcode as u32).try_into()?;
        // NOTE(review): assumes `aio_offset` is signed so this conversion
        // rejects negative offsets with EINVAL — confirm against the uapi
        // definition of `iocb`.
        let offset = control_block.aio_offset.try_into().map_err(|_| errno!(EINVAL))?;
        let flags = control_block.aio_flags;

        // The fd must be open in a mode compatible with the operation.
        match op_type {
            OpType::PRead | OpType::PReadV => {
                if !file.can_read() {
                    return error!(EBADF);
                }
            }
            OpType::PWrite | OpType::PWriteV => {
                if !file.can_write() {
                    return error!(EBADF);
                }
            }
        }
        // For the vectored ops, `aio_buf` points at an iovec array of
        // `aio_nbytes` entries; otherwise it is the data buffer itself and
        // `aio_nbytes` is its length.
        let mut buffers = match op_type {
            OpType::PRead | OpType::PWrite => smallvec![UserBuffer {
                address: control_block.aio_buf.into(),
                length: control_block.aio_nbytes as usize,
            }],
            OpType::PReadV | OpType::PWriteV => {
                let iovec_addr = IOVecPtr::new(current_task, control_block.aio_buf);
                let count: i32 = control_block.aio_nbytes.try_into().map_err(|_| errno!(EINVAL))?;
                current_task.read_iovec(iovec_addr, count.into())?
            }
        };

        // Clamp the total transfer size to the allowed maximum and verify
        // that offset + length does not overflow.
        let buffer_length = UserBuffer::cap_buffers_to_max_rw_count(
            current_task.maximum_valid_address().ok_or_else(|| errno!(EINVAL))?,
            &mut buffers,
        )?;
        checked_add_offset_and_length(offset, buffer_length)?;

        // IOCB_FLAG_RESFD: completion is additionally signalled through an
        // eventfd, which must really be an eventfd object.
        let eventfd = if flags & IOCB_FLAG_RESFD != 0 {
            let eventfd = current_task
                .live()
                .files
                .get(FdNumber::from_raw(control_block.aio_resfd as i32))?;
            if eventfd.downcast_file::<EventFdFileObject>().is_none() {
                return error!(EINVAL);
            }
            Some(Arc::downgrade(&eventfd))
        } else {
            None
        };

        Ok(IoOperation {
            op_type,
            file: Arc::downgrade(&file),
            mm: current_task.mm()?.as_remote(),
            buffers,
            offset,
            id: control_block.aio_data,
            iocb_addr,
            eventfd,
        })
    }

    /// Runs the operation on a worker thread.
    ///
    /// Returns `None` when the target file has been closed since submission,
    /// in which case there is no result to report.
    fn execute(
        &self,
        locked: &mut Locked<Unlocked>,
        current_task: &CurrentTask,
    ) -> Option<Result<SyscallResult, Errno>> {
        let Some(file) = self.file.upgrade() else {
            return None;
        };

        let result = match self.op_type {
            OpType::PRead | OpType::PReadV => {
                self.do_read(locked, current_task, file).map(Into::into)
            }
            OpType::PWrite | OpType::PWriteV => {
                self.do_write(locked, current_task, file).map(Into::into)
            }
        };
        Some(result)
    }

    /// Builds the completion `io_event`: `data` echoes the submitter's id,
    /// `obj` is the user address of the iocb, and `res` carries either the
    /// success value or the errno-derived return value.
    fn complete(&self, result: Result<SyscallResult, Errno>) -> io_event {
        let res = match result {
            Ok(return_value) => return_value.value() as i64,
            Err(errno) => errno.return_value() as i64,
        };

        io_event { data: self.id, obj: self.iocb_addr.addr().into(), res, ..Default::default() }
    }

    /// Positioned read: reads from the file into a local staging buffer
    /// (sized to the capped user buffers), then copies the data out to the
    /// submitting task's memory.
    fn do_read(
        &self,
        locked: &mut Locked<Unlocked>,
        current_task: &CurrentTask,
        file: FileHandle,
    ) -> Result<usize, Errno> {
        let buffers = self.buffers.clone();
        // Stage through a local buffer so the file read never touches the
        // remote address space directly.
        let mut output_buffer = {
            let sink = UserBuffersOutputBuffer::remote_new(&self.mm, buffers.clone())?;
            VecOutputBuffer::new(sink.available())
        };

        file.read_at(locked, current_task, self.offset, &mut output_buffer)?;

        let mut sink = UserBuffersOutputBuffer::remote_new(&self.mm, buffers)?;
        sink.write(&output_buffer.data())
    }

    /// Positioned write: copies the data in from the submitting task's
    /// memory, then writes it to the file at `self.offset`.
    fn do_write(
        &self,
        locked: &mut Locked<Unlocked>,
        current_task: &CurrentTask,
        file: FileHandle,
    ) -> Result<usize, Errno> {
        let mut input_buffer = {
            let mut source = UserBuffersInputBuffer::remote_new(&self.mm, self.buffers.clone())?;
            VecInputBuffer::new(&source.read_all()?)
        };

        file.write_at(locked, current_task, self.offset, &mut input_buffer)
    }
}
402
/// What a worker thread dequeued: either an operation to execute, or an
/// instruction to shut down (the queue has been stopped).
enum IoAction {
    Op(IoOperation),
    Stop,
}
407
/// Queued-but-not-yet-executed operations; guarded by the mutex in
/// `OperationQueue::pending`.
#[derive(Default)]
struct PendingOperations {
    /// Once set, no further operations are accepted and the queues are kept
    /// empty; workers dequeue `IoAction::Stop`.
    is_stopped: bool,
    read_ops: VecDeque<IoOperation>,
    write_ops: VecDeque<IoOperation>,
}
417
418impl PendingOperations {
419 fn ops_for(&mut self, worker_type: WorkerType) -> &mut VecDeque<IoOperation> {
420 match worker_type {
421 WorkerType::Read => &mut self.read_ops,
422 WorkerType::Write => &mut self.write_ops,
423 }
424 }
425
426 fn ops_len(&self) -> usize {
427 self.read_ops.len() + self.write_ops.len()
428 }
429}
430
/// The submission side of an AIO context: pending operations plus the wait
/// queues the two worker threads block on.
struct OperationQueue {
    /// Capacity limit across both queues; `enqueue` fails with EAGAIN beyond it.
    max_operations: usize,
    pending: Mutex<PendingOperations>,
    read_waiters: WaitQueue,
    write_waiters: WaitQueue,
}
437
438impl OperationQueue {
439 fn new(max_operations: usize) -> Self {
440 Self {
441 max_operations,
442 pending: Default::default(),
443 read_waiters: Default::default(),
444 write_waiters: Default::default(),
445 }
446 }
447
448 fn waiters_for(&self, worker_type: WorkerType) -> &WaitQueue {
449 match worker_type {
450 WorkerType::Read => &self.read_waiters,
451 WorkerType::Write => &self.write_waiters,
452 }
453 }
454
455 fn enqueue(&self, op: IoOperation) -> Result<(), Errno> {
456 let worker_type = op.op_type.worker_type();
457 {
458 let mut pending = self.pending.lock();
459 if pending.is_stopped {
460 return error!(EINVAL);
461 }
462 if pending.ops_len() >= self.max_operations {
463 return error!(EAGAIN);
464 }
465 pending.ops_for(worker_type).push_back(op);
466 }
467 self.waiters_for(worker_type).notify_unordered_count(1);
468 Ok(())
469 }
470
471 fn stop(&self) {
472 let mut pending = self.pending.lock();
473 pending.is_stopped = true;
474 pending.read_ops.clear();
475 pending.write_ops.clear();
476 self.read_waiters.notify_all();
477 self.write_waiters.notify_all();
478 }
479
480 fn dequeue(&self, worker_type: WorkerType) -> Option<IoAction> {
481 let mut pending = self.pending.lock();
482 if pending.is_stopped {
483 return Some(IoAction::Stop);
484 }
485 pending.ops_for(worker_type).pop_front().map(IoAction::Op)
486 }
487
488 fn remove(&self, worker_type: WorkerType, iocb_addr: IocbPtr) -> Result<(), Errno> {
489 {
490 let mut pending = self.pending.lock();
491 if pending.is_stopped {
492 return error!(EINVAL);
493 }
494 if let Some(idx) = pending
496 .ops_for(worker_type)
497 .iter()
498 .position(|value| value.iocb_addr.addr() == iocb_addr.addr())
499 {
500 pending.ops_for(worker_type).remove(idx);
501 } else {
502 return error!(EAGAIN);
503 }
504 }
505 Ok(())
506 }
507
508 fn block_until_dequeue(
509 &self,
510 current_task: &CurrentTask,
511 worker_type: WorkerType,
512 ) -> Result<IoAction, Errno> {
513 if let Some(action) = self.dequeue(worker_type) {
514 return Ok(action);
515 }
516 loop {
517 let event = InterruptibleEvent::new();
518 let (mut waiter, guard) = SimpleWaiter::new(&event);
519 self.waiters_for(worker_type).wait_async_simple(&mut waiter);
520 if let Some(action) = self.dequeue(worker_type) {
521 return Ok(action);
522 }
523 current_task.block_until(guard, zx::MonotonicInstant::INFINITE)?;
524 }
525 }
526}
527
/// The completion side of an AIO context: finished `io_event`s plus the wait
/// queue that `get_events` callers block on.
#[derive(Default)]
struct ResultQueue {
    waiters: WaitQueue,
    events: Mutex<VecDeque<io_event>>,
}
533
534impl ResultQueue {
535 fn enqueue(&self, event: io_event) {
536 self.events.lock().push_back(event);
537 self.waiters.notify_unordered_count(1);
538 }
539
540 fn dequeue(&self, limit: usize) -> Vec<io_event> {
541 let mut events = self.events.lock();
542 let len = std::cmp::min(events.len(), limit);
543 events.drain(..len).collect()
544 }
545}