1use crate::task::{CurrentTask, EventHandler, FullCredentials, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9 AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10 SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::errors::Errno;
14use starnix_uapi::open_flags::OpenFlags;
15use starnix_uapi::vfs::FdEvents;
16use starnix_uapi::{errno, error, ucred};
17
18pub struct VsockSocket {
22 inner: Mutex<VsockSocketInner>,
23}
24
25struct VsockSocketInner {
26 address: Option<SocketAddress>,
28
29 waiters: WaitQueue,
31
32 state: VsockSocketState,
34}
35
36enum VsockSocketState {
37 Disconnected,
39
40 Listening(AcceptQueue),
42
43 Connected { file: FileHandle, peer_addr: SocketAddress },
45
46 Closed,
48}
49
50fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
51 socket.downcast_socket::<VsockSocket>().unwrap()
56}
57
58impl VsockSocket {
59 pub fn new(_socket_type: SocketType) -> VsockSocket {
60 VsockSocket {
61 inner: Mutex::new(VsockSocketInner {
62 address: None,
63 waiters: WaitQueue::default(),
64 state: VsockSocketState::Disconnected,
65 }),
66 }
67 }
68
69 fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
71 self.inner.lock()
72 }
73}
74
75impl SocketOps for VsockSocket {
76 fn connect(
79 &self,
80 _locked: &mut Locked<FileOpsCore>,
81 _socket: &SocketHandle,
82 _current_task: &CurrentTask,
83 _peer: SocketPeer,
84 ) -> Result<(), Errno> {
85 error!(EPROTOTYPE)
86 }
87
88 fn listen(
89 &self,
90 _locked: &mut Locked<FileOpsCore>,
91 _socket: &Socket,
92 backlog: i32,
93 _credentials: ucred,
94 ) -> Result<(), Errno> {
95 let mut inner = self.lock();
96 let is_bound = inner.address.is_some();
97 let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
98 match &mut inner.state {
99 VsockSocketState::Disconnected if is_bound => {
100 inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
101 Ok(())
102 }
103 VsockSocketState::Listening(queue) => {
104 queue.set_backlog(backlog)?;
105 Ok(())
106 }
107 _ => error!(EINVAL),
108 }
109 }
110
111 fn accept(
112 &self,
113 _locked: &mut Locked<FileOpsCore>,
114 socket: &Socket,
115 _current_task: &CurrentTask,
116 ) -> Result<SocketHandle, Errno> {
117 match socket.socket_type {
118 SocketType::Stream | SocketType::SeqPacket => {}
119 _ => return error!(EOPNOTSUPP),
120 }
121 let mut inner = self.lock();
122 let queue = match &mut inner.state {
123 VsockSocketState::Listening(queue) => queue,
124 _ => return error!(EINVAL),
125 };
126 let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
127 Ok(socket)
128 }
129
130 fn bind(
131 &self,
132 _locked: &mut Locked<FileOpsCore>,
133 _socket: &Socket,
134 _current_task: &CurrentTask,
135 socket_address: SocketAddress,
136 ) -> Result<(), Errno> {
137 match socket_address {
138 SocketAddress::Vsock { .. } => {}
139 _ => return error!(EINVAL),
140 }
141 let mut inner = self.lock();
142 if inner.address.is_some() {
143 return error!(EINVAL);
144 }
145 inner.address = Some(socket_address);
146 Ok(())
147 }
148
149 fn read(
150 &self,
151 locked: &mut Locked<FileOpsCore>,
152 _socket: &Socket,
153 current_task: &CurrentTask,
154 data: &mut dyn OutputBuffer,
155 _flags: SocketMessageFlags,
156 ) -> Result<MessageReadInfo, Errno> {
157 let (address, file) = {
158 let inner = self.lock();
159 let address = inner.address.clone();
160
161 match &inner.state {
162 VsockSocketState::Connected { file, .. } => (address, file.clone()),
163 _ => return error!(EBADF),
164 }
165 };
166 let bytes_read = current_task.override_creds(FullCredentials::for_kernel(), || {
167 file.read(locked, current_task, data)
168 })?;
169 Ok(MessageReadInfo {
170 bytes_read,
171 message_length: bytes_read,
172 address,
173 ancillary_data: vec![],
174 })
175 }
176
177 fn write(
178 &self,
179 locked: &mut Locked<FileOpsCore>,
180 _socket: &Socket,
181 current_task: &CurrentTask,
182 data: &mut dyn InputBuffer,
183 _dest_address: &mut Option<SocketAddress>,
184 _ancillary_data: &mut Vec<AncillaryData>,
185 ) -> Result<usize, Errno> {
186 let file = {
187 let inner = self.lock();
188 match &inner.state {
189 VsockSocketState::Connected { file, .. } => file.clone(),
190 _ => return error!(EBADF),
191 }
192 };
193 current_task.override_creds(FullCredentials::for_kernel(), || {
194 file.write(locked, current_task, data)
195 })
196 }
197
198 fn wait_async(
199 &self,
200 locked: &mut Locked<FileOpsCore>,
201 _socket: &Socket,
202 current_task: &CurrentTask,
203 waiter: &Waiter,
204 events: FdEvents,
205 handler: EventHandler,
206 ) -> WaitCanceler {
207 let inner = self.lock();
208 match &inner.state {
209 VsockSocketState::Connected { file, .. } => file
210 .wait_async(locked, current_task, waiter, events, handler)
211 .expect("vsock socket should be connected to a file that can be waited on"),
212 _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
213 }
214 }
215
216 fn query_events(
217 &self,
218 locked: &mut Locked<FileOpsCore>,
219 _socket: &Socket,
220 current_task: &CurrentTask,
221 ) -> Result<FdEvents, Errno> {
222 self.lock().query_events(locked, current_task)
223 }
224
225 fn shutdown(
226 &self,
227 _locked: &mut Locked<FileOpsCore>,
228 _socket: &Socket,
229 _how: SocketShutdownFlags,
230 ) -> Result<(), Errno> {
231 self.lock().state = VsockSocketState::Closed;
232 Ok(())
233 }
234
235 fn close(
236 &self,
237 locked: &mut Locked<FileOpsCore>,
238 _current_task: &CurrentTask,
239 socket: &Socket,
240 ) {
241 self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
243 .unwrap();
244 }
245
246 fn getsockname(
247 &self,
248 _locked: &mut Locked<FileOpsCore>,
249 socket: &Socket,
250 ) -> Result<SocketAddress, Errno> {
251 let inner = self.lock();
252 if let Some(address) = &inner.address {
253 Ok(address.clone())
254 } else {
255 Ok(SocketAddress::default_for_domain(socket.domain))
256 }
257 }
258
259 fn getpeername(
260 &self,
261 _locked: &mut Locked<FileOpsCore>,
262 _socket: &Socket,
263 ) -> Result<SocketAddress, Errno> {
264 let inner = self.lock();
265 match &inner.state {
266 VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
267 _ => {
268 error!(ENOTCONN)
269 }
270 }
271 }
272}
273
274impl VsockSocket {
275 pub fn remote_connection<L>(
276 &self,
277 locked: &mut Locked<L>,
278 socket: &Socket,
279 current_task: &CurrentTask,
280 file: FileHandle,
281 ) -> Result<(), Errno>
282 where
283 L: LockEqualOrBefore<FileOpsCore>,
284 {
285 assert!(file.flags().contains(OpenFlags::NONBLOCK));
288 if socket.socket_type != SocketType::Stream {
289 return error!(ENOTSUP);
290 }
291 if socket.domain != SocketDomain::Vsock {
292 return error!(EINVAL);
293 }
294
295 let mut inner = self.lock();
296 match &mut inner.state {
297 VsockSocketState::Listening(queue) => {
298 if queue.sockets.len() >= queue.backlog {
299 return error!(EAGAIN);
300 }
301 let remote_socket = Socket::new(
302 locked,
303 current_task,
304 SocketDomain::Vsock,
305 SocketType::Stream,
306 SocketProtocol::default(),
307 false,
308 )?;
309 downcast_socket_to_vsock(&remote_socket).lock().state =
310 VsockSocketState::Connected {
311 file,
312 peer_addr: SocketAddress::Vsock {
313 port: u32::MAX,
314 cid: starnix_uapi::VMADDR_CID_HOST,
315 },
316 };
317 queue.sockets.push_back(remote_socket);
318 inner.waiters.notify_fd_events(FdEvents::POLLIN);
319 Ok(())
320 }
321 _ => error!(EINVAL),
322 }
323 }
324}
325
326impl VsockSocketInner {
327 fn query_events<L>(
328 &self,
329 locked: &mut Locked<L>,
330 current_task: &CurrentTask,
331 ) -> Result<FdEvents, Errno>
332 where
333 L: LockEqualOrBefore<FileOpsCore>,
334 {
335 Ok(match &self.state {
336 VsockSocketState::Disconnected => FdEvents::empty(),
337 VsockSocketState::Connected { file, .. } => current_task
338 .override_creds(FullCredentials::for_kernel(), || {
339 file.query_events(locked, current_task)
340 })?,
341 VsockSocketState::Listening(queue) => {
342 if !queue.sockets.is_empty() {
343 FdEvents::POLLIN
344 } else {
345 FdEvents::empty()
346 }
347 }
348 VsockSocketState::Closed => FdEvents::POLLHUP,
349 })
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;
    use zx::HandleBased;

    /// Binds and listens on a vsock port, injects a host connection, accepts it, and checks
    /// that bytes flow in both directions between the accepted socket and the host end.
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (fs1, fs2) = fidl::Socket::create_stream();
            const VSOCK_PORT: u32 = 5555;

            let listen_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            current_task
                .abstract_vsock_namespace
                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
                .expect("Failed to bind socket.");
            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            let listen_socket = current_task
                .abstract_vsock_namespace
                .lookup(&VSOCK_PORT)
                .expect("Failed to look up listening socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            listen_socket
                .downcast_socket::<VsockSocket>()
                .unwrap()
                .remote_connection(locked, &listen_socket, &current_task, remote)
                .unwrap();

            let server_socket = listen_socket.accept(locked, &current_task).unwrap();

            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(buffer_iterator.data(), test_bytes_in);

            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listen_socket.close(locked, &current_task);
        })
        .await;
    }

    /// A write on a connected vsock socket must not be blocked by a read that is concurrently
    /// waiting for data on the same socket.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
                file: remote,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            const XFER_SIZE: usize = 42;

            let socket_clone = socket_file.clone();
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = socket_clone
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            // Give the spawned reader time to block in `read` before any data is written.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            let mut buffer = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
            block_on(result).unwrap();
        })
        .await;
    }

    /// Readiness of a connected vsock socket tracks the backing pipe, both through
    /// `query_events` and through epoll.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
                file: pipe,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let event = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, event)
                .expect("poll_file.add");

            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());

            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(fds.len(), 1);

            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());
        })
        .await;
    }
}