1use crate::task::{CurrentTask, EventHandler, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9 AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10 SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::auth::Credentials;
14use starnix_uapi::errors::Errno;
15use starnix_uapi::open_flags::OpenFlags;
16use starnix_uapi::vfs::FdEvents;
17use starnix_uapi::{errno, error, ucred};
18
/// A socket in the vsock domain (`SocketDomain::Vsock`).
///
/// All mutable state is kept in a `VsockSocketInner` guarded by a mutex; use
/// `VsockSocket::lock` to access it.
pub struct VsockSocket {
    inner: Mutex<VsockSocketInner>,
}
25
/// Mutable state of a `VsockSocket`, protected by `VsockSocket::inner`.
struct VsockSocketInner {
    /// The local address assigned by `bind`, or `None` while the socket is unbound.
    address: Option<SocketAddress>,

    /// Queue used by `wait_async` whenever the socket is not in the `Connected` state
    /// (a connected socket waits on its transport file instead). `remote_connection`
    /// notifies it with `POLLIN` when a new connection is queued for `accept`.
    waiters: WaitQueue,

    /// Current lifecycle state of the socket.
    state: VsockSocketState,
}
36
/// Lifecycle states of a `VsockSocket`.
enum VsockSocketState {
    /// Initial state: neither listening nor connected (possibly bound).
    Disconnected,

    /// Listening for incoming connections; the queue holds connections that have
    /// arrived via `remote_connection` and are waiting to be returned by `accept`.
    Listening(AcceptQueue),

    /// Connected to a peer. `file` is the transport that all reads, writes and
    /// readiness queries are delegated to; `peer_addr` is what `getpeername` returns.
    Connected { file: FileHandle, peer_addr: SocketAddress },

    /// Shut down via `shutdown`/`close`; readiness queries report `POLLHUP`.
    Closed,
}
50
51fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
52 socket.downcast_socket::<VsockSocket>().unwrap()
57}
58
59impl VsockSocket {
60 pub fn new(_socket_type: SocketType) -> VsockSocket {
61 VsockSocket {
62 inner: Mutex::new(VsockSocketInner {
63 address: None,
64 waiters: WaitQueue::default(),
65 state: VsockSocketState::Disconnected,
66 }),
67 }
68 }
69
70 fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
72 self.inner.lock()
73 }
74}
75
76impl SocketOps for VsockSocket {
77 fn connect(
80 &self,
81 _locked: &mut Locked<FileOpsCore>,
82 _socket: &SocketHandle,
83 _current_task: &CurrentTask,
84 _peer: SocketPeer,
85 ) -> Result<(), Errno> {
86 error!(EPROTOTYPE)
87 }
88
89 fn listen(
90 &self,
91 _locked: &mut Locked<FileOpsCore>,
92 _socket: &Socket,
93 backlog: i32,
94 _credentials: ucred,
95 ) -> Result<(), Errno> {
96 let mut inner = self.lock();
97 let is_bound = inner.address.is_some();
98 let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
99 match &mut inner.state {
100 VsockSocketState::Disconnected if is_bound => {
101 inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
102 Ok(())
103 }
104 VsockSocketState::Listening(queue) => {
105 queue.set_backlog(backlog)?;
106 Ok(())
107 }
108 _ => error!(EINVAL),
109 }
110 }
111
112 fn accept(
113 &self,
114 _locked: &mut Locked<FileOpsCore>,
115 socket: &Socket,
116 _current_task: &CurrentTask,
117 ) -> Result<SocketHandle, Errno> {
118 match socket.socket_type {
119 SocketType::Stream | SocketType::SeqPacket => {}
120 _ => return error!(EOPNOTSUPP),
121 }
122 let mut inner = self.lock();
123 let queue = match &mut inner.state {
124 VsockSocketState::Listening(queue) => queue,
125 _ => return error!(EINVAL),
126 };
127 let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
128 Ok(socket)
129 }
130
131 fn bind(
132 &self,
133 _locked: &mut Locked<FileOpsCore>,
134 _socket: &Socket,
135 _current_task: &CurrentTask,
136 socket_address: SocketAddress,
137 ) -> Result<(), Errno> {
138 match socket_address {
139 SocketAddress::Vsock { .. } => {}
140 _ => return error!(EINVAL),
141 }
142 let mut inner = self.lock();
143 if inner.address.is_some() {
144 return error!(EINVAL);
145 }
146 inner.address = Some(socket_address);
147 Ok(())
148 }
149
150 fn read(
151 &self,
152 locked: &mut Locked<FileOpsCore>,
153 _socket: &Socket,
154 current_task: &CurrentTask,
155 data: &mut dyn OutputBuffer,
156 _flags: SocketMessageFlags,
157 ) -> Result<MessageReadInfo, Errno> {
158 let (address, file) = {
159 let inner = self.lock();
160 let address = inner.address.clone();
161
162 match &inner.state {
163 VsockSocketState::Connected { file, .. } => (address, file.clone()),
164 _ => return error!(EBADF),
165 }
166 };
167 let bytes_read = current_task
168 .override_creds(Credentials::root(), || file.read(locked, current_task, data))?;
169 Ok(MessageReadInfo {
170 bytes_read,
171 message_length: bytes_read,
172 address,
173 ancillary_data: vec![],
174 })
175 }
176
177 fn write(
178 &self,
179 locked: &mut Locked<FileOpsCore>,
180 _socket: &Socket,
181 current_task: &CurrentTask,
182 data: &mut dyn InputBuffer,
183 _dest_address: &mut Option<SocketAddress>,
184 _ancillary_data: &mut Vec<AncillaryData>,
185 ) -> Result<usize, Errno> {
186 let file = {
187 let inner = self.lock();
188 match &inner.state {
189 VsockSocketState::Connected { file, .. } => file.clone(),
190 _ => return error!(EBADF),
191 }
192 };
193 current_task.override_creds(Credentials::root(), || file.write(locked, current_task, data))
194 }
195
196 fn wait_async(
197 &self,
198 locked: &mut Locked<FileOpsCore>,
199 _socket: &Socket,
200 current_task: &CurrentTask,
201 waiter: &Waiter,
202 events: FdEvents,
203 handler: EventHandler,
204 ) -> WaitCanceler {
205 let inner = self.lock();
206 match &inner.state {
207 VsockSocketState::Connected { file, .. } => file
208 .wait_async(locked, current_task, waiter, events, handler)
209 .expect("vsock socket should be connected to a file that can be waited on"),
210 _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
211 }
212 }
213
214 fn query_events(
215 &self,
216 locked: &mut Locked<FileOpsCore>,
217 _socket: &Socket,
218 current_task: &CurrentTask,
219 ) -> Result<FdEvents, Errno> {
220 self.lock().query_events(locked, current_task)
221 }
222
223 fn shutdown(
224 &self,
225 _locked: &mut Locked<FileOpsCore>,
226 _socket: &Socket,
227 _how: SocketShutdownFlags,
228 ) -> Result<(), Errno> {
229 self.lock().state = VsockSocketState::Closed;
230 Ok(())
231 }
232
233 fn close(
234 &self,
235 locked: &mut Locked<FileOpsCore>,
236 _current_task: &CurrentTask,
237 socket: &Socket,
238 ) {
239 self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
241 .unwrap();
242 }
243
244 fn getsockname(
245 &self,
246 _locked: &mut Locked<FileOpsCore>,
247 socket: &Socket,
248 ) -> Result<SocketAddress, Errno> {
249 let inner = self.lock();
250 if let Some(address) = &inner.address {
251 Ok(address.clone())
252 } else {
253 Ok(SocketAddress::default_for_domain(socket.domain))
254 }
255 }
256
257 fn getpeername(
258 &self,
259 _locked: &mut Locked<FileOpsCore>,
260 _socket: &Socket,
261 ) -> Result<SocketAddress, Errno> {
262 let inner = self.lock();
263 match &inner.state {
264 VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
265 _ => {
266 error!(ENOTCONN)
267 }
268 }
269 }
270}
271
272impl VsockSocket {
273 pub fn remote_connection<L>(
274 &self,
275 locked: &mut Locked<L>,
276 socket: &Socket,
277 current_task: &CurrentTask,
278 file: FileHandle,
279 ) -> Result<(), Errno>
280 where
281 L: LockEqualOrBefore<FileOpsCore>,
282 {
283 assert!(file.flags().contains(OpenFlags::NONBLOCK));
286 if socket.socket_type != SocketType::Stream {
287 return error!(ENOTSUP);
288 }
289 if socket.domain != SocketDomain::Vsock {
290 return error!(EINVAL);
291 }
292
293 let mut inner = self.lock();
294 match &mut inner.state {
295 VsockSocketState::Listening(queue) => {
296 if queue.sockets.len() >= queue.backlog {
297 return error!(EAGAIN);
298 }
299 let remote_socket = Socket::new(
300 locked,
301 current_task,
302 SocketDomain::Vsock,
303 SocketType::Stream,
304 SocketProtocol::default(),
305 true,
306 )?;
307 downcast_socket_to_vsock(&remote_socket).lock().state =
308 VsockSocketState::Connected {
309 file,
310 peer_addr: SocketAddress::Vsock {
311 port: u32::MAX,
312 cid: starnix_uapi::VMADDR_CID_HOST,
313 },
314 };
315 queue.sockets.push_back(remote_socket);
316 inner.waiters.notify_fd_events(FdEvents::POLLIN);
317 Ok(())
318 }
319 _ => error!(EINVAL),
320 }
321 }
322}
323
324impl VsockSocketInner {
325 fn query_events<L>(
326 &self,
327 locked: &mut Locked<L>,
328 current_task: &CurrentTask,
329 ) -> Result<FdEvents, Errno>
330 where
331 L: LockEqualOrBefore<FileOpsCore>,
332 {
333 Ok(match &self.state {
334 VsockSocketState::Disconnected => FdEvents::empty(),
335 VsockSocketState::Connected { file, .. } => current_task
336 .override_creds(Credentials::root(), || file.query_events(locked, current_task))?,
337 VsockSocketState::Listening(queue) => {
338 if !queue.sockets.is_empty() {
339 FdEvents::POLLIN
340 } else {
341 FdEvents::empty()
342 }
343 }
344 VsockSocketState::Closed => FdEvents::POLLHUP,
345 })
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;
    use zx::HandleBased;

    /// Binds and listens on a vsock port, injects a host connection with
    /// `remote_connection`, accepts it, and checks data flows in both directions.
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (fs1, fs2) = fidl::Socket::create_stream();
            const VSOCK_PORT: u32 = 5555;

            let listen_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let vsock_ns = current_task.live().abstract_vsock_namespace.clone();
            vsock_ns
                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
                .expect("Failed to bind socket.");
            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            let listen_socket =
                vsock_ns.lookup(&VSOCK_PORT).expect("Failed to look up listening socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            listen_socket
                .downcast_socket::<VsockSocket>()
                .unwrap()
                .remote_connection(locked, &listen_socket, &current_task, remote)
                .unwrap();

            let server_socket = listen_socket.accept(locked, &current_task).unwrap();

            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(buffer_iterator.data(), test_bytes_in);

            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listen_socket.close(locked, &current_task);
        })
        .await;
    }

    /// Checks that a write on a connected vsock socket makes progress while another
    /// kernel thread is reading from the same socket.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
                file: remote,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            const XFER_SIZE: usize = 42;

            let socket_clone = socket_file.clone();
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = socket_clone
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            // Give the spawned reader time to (most likely) block in `read` before the
            // write below, so the write actually happens while a read is in progress.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            let mut buffer = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
            block_on(result).unwrap();
        })
        .await;
    }

    /// Checks that readiness (via `query_events` and epoll) of a connected vsock
    /// socket tracks the readiness of its transport.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
                file: pipe,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let event = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, event)
                .expect("poll_file.add");

            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());

            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(fds.len(), 1);

            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());
        })
        .await;
    }
}