1use crate::task::{CurrentTask, EventHandler, FullCredentials, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9 AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10 SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::errors::Errno;
14use starnix_uapi::open_flags::OpenFlags;
15use starnix_uapi::vfs::FdEvents;
16use starnix_uapi::{errno, error, ucred};
17
/// [`SocketOps`] implementation backing sockets created in the `SocketDomain::Vsock` domain.
///
/// All mutable state lives behind a single mutex (see `VsockSocketInner`).
pub struct VsockSocket {
    /// Mutable socket state; acquire it through `VsockSocket::lock`.
    inner: Mutex<VsockSocketInner>,
}
24
/// Mutable state of a [`VsockSocket`], guarded by its mutex.
struct VsockSocketInner {
    /// The address recorded by `bind`, or `None` while the socket is unbound.
    address: Option<SocketAddress>,

    /// Wait queue used by `wait_async` whenever the socket is not connected to a backing file.
    /// `remote_connection` notifies it with `POLLIN` when a connection is queued.
    waiters: WaitQueue,

    /// Current lifecycle state of the socket.
    state: VsockSocketState,
}
35
/// Lifecycle state of a [`VsockSocket`].
enum VsockSocketState {
    /// Initial state after `VsockSocket::new`: neither listening nor connected.
    Disconnected,

    /// `listen` was called on a bound socket. Connections delivered by `remote_connection`
    /// are queued here until `accept` pops them.
    Listening(AcceptQueue),

    /// Connected to a remote endpoint. Reads and writes go to `file`, and `peer_addr` is what
    /// `getpeername` reports.
    Connected { file: FileHandle, peer_addr: SocketAddress },

    /// The socket has been shut down; `query_events` reports `POLLHUP`.
    Closed,
}
49
50fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
51 socket.downcast_socket::<VsockSocket>().unwrap()
56}
57
58impl VsockSocket {
59 pub fn new(_socket_type: SocketType) -> VsockSocket {
60 VsockSocket {
61 inner: Mutex::new(VsockSocketInner {
62 address: None,
63 waiters: WaitQueue::default(),
64 state: VsockSocketState::Disconnected,
65 }),
66 }
67 }
68
69 fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
71 self.inner.lock()
72 }
73}
74
75impl SocketOps for VsockSocket {
76 fn connect(
79 &self,
80 _locked: &mut Locked<FileOpsCore>,
81 _socket: &SocketHandle,
82 _current_task: &CurrentTask,
83 _peer: SocketPeer,
84 ) -> Result<(), Errno> {
85 error!(EPROTOTYPE)
86 }
87
88 fn listen(
89 &self,
90 _locked: &mut Locked<FileOpsCore>,
91 _socket: &Socket,
92 backlog: i32,
93 _credentials: ucred,
94 ) -> Result<(), Errno> {
95 let mut inner = self.lock();
96 let is_bound = inner.address.is_some();
97 let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
98 match &mut inner.state {
99 VsockSocketState::Disconnected if is_bound => {
100 inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
101 Ok(())
102 }
103 VsockSocketState::Listening(queue) => {
104 queue.set_backlog(backlog)?;
105 Ok(())
106 }
107 _ => error!(EINVAL),
108 }
109 }
110
111 fn accept(
112 &self,
113 _locked: &mut Locked<FileOpsCore>,
114 socket: &Socket,
115 _current_task: &CurrentTask,
116 ) -> Result<SocketHandle, Errno> {
117 match socket.socket_type {
118 SocketType::Stream | SocketType::SeqPacket => {}
119 _ => return error!(EOPNOTSUPP),
120 }
121 let mut inner = self.lock();
122 let queue = match &mut inner.state {
123 VsockSocketState::Listening(queue) => queue,
124 _ => return error!(EINVAL),
125 };
126 let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
127 Ok(socket)
128 }
129
130 fn bind(
131 &self,
132 _locked: &mut Locked<FileOpsCore>,
133 _socket: &Socket,
134 _current_task: &CurrentTask,
135 socket_address: SocketAddress,
136 ) -> Result<(), Errno> {
137 match socket_address {
138 SocketAddress::Vsock { .. } => {}
139 _ => return error!(EINVAL),
140 }
141 let mut inner = self.lock();
142 if inner.address.is_some() {
143 return error!(EINVAL);
144 }
145 inner.address = Some(socket_address);
146 Ok(())
147 }
148
149 fn read(
150 &self,
151 locked: &mut Locked<FileOpsCore>,
152 _socket: &Socket,
153 current_task: &CurrentTask,
154 data: &mut dyn OutputBuffer,
155 _flags: SocketMessageFlags,
156 ) -> Result<MessageReadInfo, Errno> {
157 let (address, file) = {
158 let inner = self.lock();
159 let address = inner.address.clone();
160
161 match &inner.state {
162 VsockSocketState::Connected { file, .. } => (address, file.clone()),
163 _ => return error!(EBADF),
164 }
165 };
166 let bytes_read = current_task.override_creds(
167 |creds| *creds = FullCredentials::for_kernel(),
168 || file.read(locked, current_task, data),
169 )?;
170 Ok(MessageReadInfo {
171 bytes_read,
172 message_length: bytes_read,
173 address,
174 ancillary_data: vec![],
175 })
176 }
177
178 fn write(
179 &self,
180 locked: &mut Locked<FileOpsCore>,
181 _socket: &Socket,
182 current_task: &CurrentTask,
183 data: &mut dyn InputBuffer,
184 _dest_address: &mut Option<SocketAddress>,
185 _ancillary_data: &mut Vec<AncillaryData>,
186 ) -> Result<usize, Errno> {
187 let file = {
188 let inner = self.lock();
189 match &inner.state {
190 VsockSocketState::Connected { file, .. } => file.clone(),
191 _ => return error!(EBADF),
192 }
193 };
194 current_task.override_creds(
195 |creds| *creds = FullCredentials::for_kernel(),
196 || file.write(locked, current_task, data),
197 )
198 }
199
200 fn wait_async(
201 &self,
202 locked: &mut Locked<FileOpsCore>,
203 _socket: &Socket,
204 current_task: &CurrentTask,
205 waiter: &Waiter,
206 events: FdEvents,
207 handler: EventHandler,
208 ) -> WaitCanceler {
209 let inner = self.lock();
210 match &inner.state {
211 VsockSocketState::Connected { file, .. } => file
212 .wait_async(locked, current_task, waiter, events, handler)
213 .expect("vsock socket should be connected to a file that can be waited on"),
214 _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
215 }
216 }
217
218 fn query_events(
219 &self,
220 locked: &mut Locked<FileOpsCore>,
221 _socket: &Socket,
222 current_task: &CurrentTask,
223 ) -> Result<FdEvents, Errno> {
224 self.lock().query_events(locked, current_task)
225 }
226
227 fn shutdown(
228 &self,
229 _locked: &mut Locked<FileOpsCore>,
230 _socket: &Socket,
231 _how: SocketShutdownFlags,
232 ) -> Result<(), Errno> {
233 self.lock().state = VsockSocketState::Closed;
234 Ok(())
235 }
236
237 fn close(
238 &self,
239 locked: &mut Locked<FileOpsCore>,
240 _current_task: &CurrentTask,
241 socket: &Socket,
242 ) {
243 self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
245 .unwrap();
246 }
247
248 fn getsockname(
249 &self,
250 _locked: &mut Locked<FileOpsCore>,
251 socket: &Socket,
252 ) -> Result<SocketAddress, Errno> {
253 let inner = self.lock();
254 if let Some(address) = &inner.address {
255 Ok(address.clone())
256 } else {
257 Ok(SocketAddress::default_for_domain(socket.domain))
258 }
259 }
260
261 fn getpeername(
262 &self,
263 _locked: &mut Locked<FileOpsCore>,
264 _socket: &Socket,
265 ) -> Result<SocketAddress, Errno> {
266 let inner = self.lock();
267 match &inner.state {
268 VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
269 _ => {
270 error!(ENOTCONN)
271 }
272 }
273 }
274}
275
276impl VsockSocket {
277 pub fn remote_connection<L>(
278 &self,
279 locked: &mut Locked<L>,
280 socket: &Socket,
281 current_task: &CurrentTask,
282 file: FileHandle,
283 ) -> Result<(), Errno>
284 where
285 L: LockEqualOrBefore<FileOpsCore>,
286 {
287 assert!(file.flags().contains(OpenFlags::NONBLOCK));
290 if socket.socket_type != SocketType::Stream {
291 return error!(ENOTSUP);
292 }
293 if socket.domain != SocketDomain::Vsock {
294 return error!(EINVAL);
295 }
296
297 let mut inner = self.lock();
298 match &mut inner.state {
299 VsockSocketState::Listening(queue) => {
300 if queue.sockets.len() >= queue.backlog {
301 return error!(EAGAIN);
302 }
303 let remote_socket = Socket::new(
304 locked,
305 current_task,
306 SocketDomain::Vsock,
307 SocketType::Stream,
308 SocketProtocol::default(),
309 false,
310 )?;
311 downcast_socket_to_vsock(&remote_socket).lock().state =
312 VsockSocketState::Connected {
313 file,
314 peer_addr: SocketAddress::Vsock {
315 port: u32::MAX,
316 cid: starnix_uapi::VMADDR_CID_HOST,
317 },
318 };
319 queue.sockets.push_back(remote_socket);
320 inner.waiters.notify_fd_events(FdEvents::POLLIN);
321 Ok(())
322 }
323 _ => error!(EINVAL),
324 }
325 }
326}
327
328impl VsockSocketInner {
329 fn query_events<L>(
330 &self,
331 locked: &mut Locked<L>,
332 current_task: &CurrentTask,
333 ) -> Result<FdEvents, Errno>
334 where
335 L: LockEqualOrBefore<FileOpsCore>,
336 {
337 Ok(match &self.state {
338 VsockSocketState::Disconnected => FdEvents::empty(),
339 VsockSocketState::Connected { file, .. } => current_task.override_creds(
340 |creds| *creds = FullCredentials::for_kernel(),
341 || file.query_events(locked, current_task),
342 )?,
343 VsockSocketState::Listening(queue) => {
344 if !queue.sockets.is_empty() {
345 FdEvents::POLLIN
346 } else {
347 FdEvents::empty()
348 }
349 }
350 VsockSocketState::Closed => FdEvents::POLLHUP,
351 })
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;
    use zx::HandleBased;

    /// Full flow: bind and listen through the abstract vsock namespace, deliver a remote
    /// connection, accept it, and move bytes in both directions through the backing pipe.
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (fs1, fs2) = fidl::Socket::create_stream();
            const VSOCK_PORT: u32 = 5555;

            let listen_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            current_task
                .abstract_vsock_namespace
                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
                .expect("Failed to bind socket.");
            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            let listen_socket = current_task
                .abstract_vsock_namespace
                .lookup(&VSOCK_PORT)
                .expect("Failed to look up listening socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            listen_socket
                .downcast_socket::<VsockSocket>()
                .unwrap()
                .remote_connection(locked, &listen_socket, &current_task, remote)
                .unwrap();

            let server_socket = listen_socket.accept(locked, &current_task).unwrap();

            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(buffer_iterator.data(), test_bytes_in);

            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listen_socket.close(locked, &current_task);
        })
        .await;
    }

    /// A write on the socket must still go through while another kernel thread is blocked
    /// reading from the same socket file.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
                file: remote,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            const XFER_SIZE: usize = 42;

            let socket_clone = socket_file.clone();
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = socket_clone
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            // Give the spawned reader time to block in `read` before writing.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            let mut buffer = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
            block_on(result).unwrap();
        })
        .await;
    }

    /// `query_events` and epoll both track readability of the backing pipe.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
                file: pipe,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let event = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, event)
                .expect("poll_file.add");

            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());

            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(fds.len(), 1);

            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());
        })
        .await;
    }
}