1use crate::task::{CurrentTask, EventHandler, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9 AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10 SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::auth::Credentials;
14use starnix_uapi::errors::Errno;
15use starnix_uapi::open_flags::OpenFlags;
16use starnix_uapi::vfs::FdEvents;
17use starnix_uapi::{errno, error, ucred};
18
/// A socket in the `Vsock` domain.
///
/// All mutable state (bound address, waiters, connection state) lives behind a single
/// mutex; see `VsockSocketInner`.
pub struct VsockSocket {
    /// Mutable socket state; acquire through `VsockSocket::lock`.
    inner: Mutex<VsockSocketInner>,
}
25
/// Mutable state of a `VsockSocket`, guarded by `VsockSocket::inner`.
struct VsockSocketInner {
    /// The address supplied to `bind`, if the socket has been bound.
    address: Option<SocketAddress>,

    /// Waiters used whenever the socket is not connected (for example a listening
    /// socket waiting for an incoming connection). Connected sockets delegate waiting
    /// to their backing file instead.
    waiters: WaitQueue,

    /// Current connection state of the socket.
    state: VsockSocketState,
}
36
/// Connection state of a `VsockSocket`.
enum VsockSocketState {
    /// Initial state: neither listening nor connected.
    Disconnected,

    /// Listening; the queue holds connected sockets created by
    /// `VsockSocket::remote_connection` until they are returned by `accept`.
    Listening(AcceptQueue),

    /// Connected: reads, writes and polling are forwarded to `file`, and `peer_addr` is
    /// what `getpeername` reports.
    Connected { file: FileHandle, peer_addr: SocketAddress },

    /// Shut down or closed; polling reports `POLLHUP`.
    Closed,
}
50
51fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
52 socket.downcast_socket::<VsockSocket>().unwrap()
57}
58
59impl VsockSocket {
60 pub fn new(_socket_type: SocketType) -> VsockSocket {
61 VsockSocket {
62 inner: Mutex::new(VsockSocketInner {
63 address: None,
64 waiters: WaitQueue::default(),
65 state: VsockSocketState::Disconnected,
66 }),
67 }
68 }
69
70 fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
72 self.inner.lock()
73 }
74}
75
76impl SocketOps for VsockSocket {
77 fn connect(
80 &self,
81 _locked: &mut Locked<FileOpsCore>,
82 _socket: &SocketHandle,
83 _current_task: &CurrentTask,
84 _peer: SocketPeer,
85 ) -> Result<(), Errno> {
86 error!(EPROTOTYPE)
87 }
88
89 fn listen(
90 &self,
91 _locked: &mut Locked<FileOpsCore>,
92 _socket: &Socket,
93 backlog: i32,
94 _credentials: ucred,
95 ) -> Result<(), Errno> {
96 let mut inner = self.lock();
97 let is_bound = inner.address.is_some();
98 let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
99 match &mut inner.state {
100 VsockSocketState::Disconnected if is_bound => {
101 inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
102 Ok(())
103 }
104 VsockSocketState::Listening(queue) => {
105 queue.set_backlog(backlog)?;
106 Ok(())
107 }
108 _ => error!(EINVAL),
109 }
110 }
111
112 fn accept(
113 &self,
114 _locked: &mut Locked<FileOpsCore>,
115 socket: &Socket,
116 _current_task: &CurrentTask,
117 ) -> Result<SocketHandle, Errno> {
118 match socket.socket_type {
119 SocketType::Stream | SocketType::SeqPacket => {}
120 _ => return error!(EOPNOTSUPP),
121 }
122 let mut inner = self.lock();
123 let queue = match &mut inner.state {
124 VsockSocketState::Listening(queue) => queue,
125 _ => return error!(EINVAL),
126 };
127 let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
128 Ok(socket)
129 }
130
131 fn bind(
132 &self,
133 _locked: &mut Locked<FileOpsCore>,
134 _socket: &Socket,
135 _current_task: &CurrentTask,
136 socket_address: SocketAddress,
137 ) -> Result<(), Errno> {
138 match socket_address {
139 SocketAddress::Vsock { .. } => {}
140 _ => return error!(EINVAL),
141 }
142 let mut inner = self.lock();
143 if inner.address.is_some() {
144 return error!(EINVAL);
145 }
146 inner.address = Some(socket_address);
147 Ok(())
148 }
149
150 fn read(
151 &self,
152 locked: &mut Locked<FileOpsCore>,
153 _socket: &Socket,
154 current_task: &CurrentTask,
155 data: &mut dyn OutputBuffer,
156 _flags: SocketMessageFlags,
157 ) -> Result<MessageReadInfo, Errno> {
158 let (address, file) = {
159 let inner = self.lock();
160 let address = inner.address.clone();
161
162 match &inner.state {
163 VsockSocketState::Connected { file, .. } => (address, file.clone()),
164 _ => return error!(EBADF),
165 }
166 };
167 let bytes_read = current_task
168 .override_creds(Credentials::root(), || file.read(locked, current_task, data))?;
169 Ok(MessageReadInfo {
170 bytes_read,
171 message_length: bytes_read,
172 address,
173 ancillary_data: vec![],
174 })
175 }
176
177 fn write(
178 &self,
179 locked: &mut Locked<FileOpsCore>,
180 _socket: &Socket,
181 current_task: &CurrentTask,
182 data: &mut dyn InputBuffer,
183 _dest_address: &mut Option<SocketAddress>,
184 _ancillary_data: &mut Vec<AncillaryData>,
185 ) -> Result<usize, Errno> {
186 let file = {
187 let inner = self.lock();
188 match &inner.state {
189 VsockSocketState::Connected { file, .. } => file.clone(),
190 _ => return error!(EBADF),
191 }
192 };
193 current_task.override_creds(Credentials::root(), || file.write(locked, current_task, data))
194 }
195
196 fn wait_async(
197 &self,
198 locked: &mut Locked<FileOpsCore>,
199 _socket: &Socket,
200 current_task: &CurrentTask,
201 waiter: &Waiter,
202 events: FdEvents,
203 handler: EventHandler,
204 ) -> WaitCanceler {
205 let inner = self.lock();
206 match &inner.state {
207 VsockSocketState::Connected { file, .. } => file
208 .wait_async(locked, current_task, waiter, events, handler)
209 .expect("vsock socket should be connected to a file that can be waited on"),
210 _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
211 }
212 }
213
214 fn query_events(
215 &self,
216 locked: &mut Locked<FileOpsCore>,
217 _socket: &Socket,
218 current_task: &CurrentTask,
219 ) -> Result<FdEvents, Errno> {
220 self.lock().query_events(locked, current_task)
221 }
222
223 fn shutdown(
224 &self,
225 _locked: &mut Locked<FileOpsCore>,
226 _socket: &Socket,
227 _how: SocketShutdownFlags,
228 ) -> Result<(), Errno> {
229 self.lock().state = VsockSocketState::Closed;
230 Ok(())
231 }
232
233 fn close(
234 &self,
235 locked: &mut Locked<FileOpsCore>,
236 _current_task: &CurrentTask,
237 socket: &Socket,
238 ) {
239 self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
241 .unwrap();
242 }
243
244 fn getsockname(
245 &self,
246 _locked: &mut Locked<FileOpsCore>,
247 socket: &Socket,
248 ) -> Result<SocketAddress, Errno> {
249 let inner = self.lock();
250 if let Some(address) = &inner.address {
251 Ok(address.clone())
252 } else {
253 Ok(SocketAddress::default_for_domain(socket.domain))
254 }
255 }
256
257 fn getpeername(
258 &self,
259 _locked: &mut Locked<FileOpsCore>,
260 _socket: &Socket,
261 ) -> Result<SocketAddress, Errno> {
262 let inner = self.lock();
263 match &inner.state {
264 VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
265 _ => {
266 error!(ENOTCONN)
267 }
268 }
269 }
270}
271
272impl VsockSocket {
273 pub fn remote_connection<L>(
274 &self,
275 locked: &mut Locked<L>,
276 socket: &Socket,
277 current_task: &CurrentTask,
278 file: FileHandle,
279 ) -> Result<(), Errno>
280 where
281 L: LockEqualOrBefore<FileOpsCore>,
282 {
283 assert!(file.flags().contains(OpenFlags::NONBLOCK));
286 if socket.socket_type != SocketType::Stream {
287 return error!(ENOTSUP);
288 }
289 if socket.domain != SocketDomain::Vsock {
290 return error!(EINVAL);
291 }
292
293 let mut inner = self.lock();
294 match &mut inner.state {
295 VsockSocketState::Listening(queue) => {
296 if queue.sockets.len() >= queue.backlog {
297 return error!(EAGAIN);
298 }
299 let remote_socket = Socket::new(
300 locked,
301 current_task,
302 SocketDomain::Vsock,
303 SocketType::Stream,
304 SocketProtocol::default(),
305 true,
306 )?;
307 downcast_socket_to_vsock(&remote_socket).lock().state =
308 VsockSocketState::Connected {
309 file,
310 peer_addr: SocketAddress::Vsock {
311 port: u32::MAX,
312 cid: starnix_uapi::VMADDR_CID_HOST,
313 },
314 };
315 queue.sockets.push_back(remote_socket);
316 inner.waiters.notify_fd_events(FdEvents::POLLIN);
317 Ok(())
318 }
319 _ => error!(EINVAL),
320 }
321 }
322}
323
324impl VsockSocketInner {
325 fn query_events<L>(
326 &self,
327 locked: &mut Locked<L>,
328 current_task: &CurrentTask,
329 ) -> Result<FdEvents, Errno>
330 where
331 L: LockEqualOrBefore<FileOpsCore>,
332 {
333 Ok(match &self.state {
334 VsockSocketState::Disconnected => FdEvents::empty(),
335 VsockSocketState::Connected { file, .. } => current_task
336 .override_creds(Credentials::root(), || file.query_events(locked, current_task))?,
337 VsockSocketState::Listening(queue) => {
338 if !queue.sockets.is_empty() {
339 FdEvents::POLLIN
340 } else {
341 FdEvents::empty()
342 }
343 }
344 VsockSocketState::Closed => FdEvents::POLLHUP,
345 })
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;
    use zx::HandleBased;

    /// Binds and listens on a vsock port, delivers a remote connection backed by a
    /// zircon socket, accepts it, and checks data flows in both directions.
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (fs1, fs2) = fidl::Socket::create_stream();
            const VSOCK_PORT: u32 = 5555;

            let listen_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            current_task
                .abstract_vsock_namespace
                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
                .expect("Failed to bind socket.");
            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            let listen_socket = current_task
                .abstract_vsock_namespace
                .lookup(&VSOCK_PORT)
                .expect("Failed to look up listening socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            listen_socket
                .downcast_socket::<VsockSocket>()
                .unwrap()
                .remote_connection(locked, &listen_socket, &current_task, remote)
                .unwrap();

            let server_socket = listen_socket.accept(locked, &current_task).unwrap();

            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(buffer_iterator.data(), test_bytes_in);

            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listen_socket.close(locked, &current_task);
        })
        .await;
    }

    /// Checks that a write through the socket file succeeds while another thread is
    /// blocked reading from the same socket.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
                file: remote,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            const XFER_SIZE: usize = 42;

            let socket_clone = socket_file.clone();
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = socket_clone
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            // Give the spawned reader time to block in `read`, so that the write below
            // happens while that read is still pending.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            let mut buffer = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
            block_on(result).unwrap();
        })
        .await;
    }

    /// Checks that `query_events` and epoll track readability of the backing file.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
                file: pipe,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let event = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, event)
                .expect("poll_file.add");

            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());

            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(fds.len(), 1);

            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());
        })
        .await;
    }
}
570}