1use crate::task::{CurrentTask, EventHandler, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9 AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10 SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::auth::Credentials;
14use starnix_uapi::errors::Errno;
15use starnix_uapi::open_flags::OpenFlags;
16use starnix_uapi::vfs::FdEvents;
17use starnix_uapi::{errno, error, ucred};
18
19pub struct VsockSocket {
23 inner: Mutex<VsockSocketInner>,
24}
25
26struct VsockSocketInner {
27 address: Option<SocketAddress>,
29
30 waiters: WaitQueue,
32
33 state: VsockSocketState,
35}
36
37enum VsockSocketState {
38 Disconnected,
40
41 Listening(AcceptQueue),
43
44 Connected { file: FileHandle, peer_addr: SocketAddress },
46
47 Closed,
49}
50
51fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
52 socket.downcast_socket::<VsockSocket>().unwrap()
57}
58
59impl VsockSocket {
60 pub fn new(_socket_type: SocketType) -> VsockSocket {
61 VsockSocket {
62 inner: Mutex::new(VsockSocketInner {
63 address: None,
64 waiters: WaitQueue::default(),
65 state: VsockSocketState::Disconnected,
66 }),
67 }
68 }
69
70 fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
72 self.inner.lock()
73 }
74}
75
76impl SocketOps for VsockSocket {
77 fn connect(
80 &self,
81 _locked: &mut Locked<FileOpsCore>,
82 _socket: &SocketHandle,
83 _current_task: &CurrentTask,
84 _peer: SocketPeer,
85 ) -> Result<(), Errno> {
86 error!(EPROTOTYPE)
87 }
88
89 fn listen(
90 &self,
91 _locked: &mut Locked<FileOpsCore>,
92 _socket: &Socket,
93 backlog: i32,
94 _credentials: ucred,
95 ) -> Result<(), Errno> {
96 let mut inner = self.lock();
97 let is_bound = inner.address.is_some();
98 let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
99 match &mut inner.state {
100 VsockSocketState::Disconnected if is_bound => {
101 inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
102 Ok(())
103 }
104 VsockSocketState::Listening(queue) => {
105 queue.set_backlog(backlog)?;
106 Ok(())
107 }
108 _ => error!(EINVAL),
109 }
110 }
111
112 fn accept(
113 &self,
114 _locked: &mut Locked<FileOpsCore>,
115 socket: &Socket,
116 _current_task: &CurrentTask,
117 ) -> Result<SocketHandle, Errno> {
118 match socket.socket_type {
119 SocketType::Stream | SocketType::SeqPacket => {}
120 _ => return error!(EOPNOTSUPP),
121 }
122 let mut inner = self.lock();
123 let queue = match &mut inner.state {
124 VsockSocketState::Listening(queue) => queue,
125 _ => return error!(EINVAL),
126 };
127 let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
128 Ok(socket)
129 }
130
131 fn bind(
132 &self,
133 _locked: &mut Locked<FileOpsCore>,
134 _socket: &Socket,
135 _current_task: &CurrentTask,
136 socket_address: SocketAddress,
137 ) -> Result<(), Errno> {
138 match socket_address {
139 SocketAddress::Vsock { .. } => {}
140 _ => return error!(EINVAL),
141 }
142 let mut inner = self.lock();
143 if inner.address.is_some() {
144 return error!(EINVAL);
145 }
146 inner.address = Some(socket_address);
147 Ok(())
148 }
149
150 fn read(
151 &self,
152 locked: &mut Locked<FileOpsCore>,
153 _socket: &Socket,
154 current_task: &CurrentTask,
155 data: &mut dyn OutputBuffer,
156 _flags: SocketMessageFlags,
157 ) -> Result<MessageReadInfo, Errno> {
158 let (address, file) = {
159 let inner = self.lock();
160 let address = inner.address.clone();
161
162 match &inner.state {
163 VsockSocketState::Connected { file, .. } => (address, file.clone()),
164 _ => return error!(EBADF),
165 }
166 };
167 let bytes_read = current_task
168 .override_creds(Credentials::root(), || file.read(locked, current_task, data))?;
169 Ok(MessageReadInfo {
170 bytes_read,
171 message_length: bytes_read,
172 address,
173 ancillary_data: vec![],
174 })
175 }
176
177 fn write(
178 &self,
179 locked: &mut Locked<FileOpsCore>,
180 _socket: &Socket,
181 current_task: &CurrentTask,
182 data: &mut dyn InputBuffer,
183 _dest_address: &mut Option<SocketAddress>,
184 _ancillary_data: &mut Vec<AncillaryData>,
185 ) -> Result<usize, Errno> {
186 let file = {
187 let inner = self.lock();
188 match &inner.state {
189 VsockSocketState::Connected { file, .. } => file.clone(),
190 _ => return error!(EBADF),
191 }
192 };
193 current_task.override_creds(Credentials::root(), || file.write(locked, current_task, data))
194 }
195
196 fn wait_async(
197 &self,
198 locked: &mut Locked<FileOpsCore>,
199 _socket: &Socket,
200 current_task: &CurrentTask,
201 waiter: &Waiter,
202 events: FdEvents,
203 handler: EventHandler,
204 ) -> WaitCanceler {
205 let inner = self.lock();
206 match &inner.state {
207 VsockSocketState::Connected { file, .. } => file
208 .wait_async(locked, current_task, waiter, events, handler)
209 .expect("vsock socket should be connected to a file that can be waited on"),
210 _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
211 }
212 }
213
214 fn query_events(
215 &self,
216 locked: &mut Locked<FileOpsCore>,
217 _socket: &Socket,
218 current_task: &CurrentTask,
219 ) -> Result<FdEvents, Errno> {
220 self.lock().query_events(locked, current_task)
221 }
222
223 fn shutdown(
224 &self,
225 _locked: &mut Locked<FileOpsCore>,
226 _socket: &Socket,
227 _how: SocketShutdownFlags,
228 ) -> Result<(), Errno> {
229 self.lock().state = VsockSocketState::Closed;
230 Ok(())
231 }
232
233 fn close(
234 &self,
235 locked: &mut Locked<FileOpsCore>,
236 _current_task: &CurrentTask,
237 socket: &Socket,
238 ) {
239 self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
241 .unwrap();
242 }
243
244 fn getsockname(
245 &self,
246 _locked: &mut Locked<FileOpsCore>,
247 socket: &Socket,
248 ) -> Result<SocketAddress, Errno> {
249 let inner = self.lock();
250 if let Some(address) = &inner.address {
251 Ok(address.clone())
252 } else {
253 Ok(SocketAddress::default_for_domain(socket.domain))
254 }
255 }
256
257 fn getpeername(
258 &self,
259 _locked: &mut Locked<FileOpsCore>,
260 _socket: &Socket,
261 ) -> Result<SocketAddress, Errno> {
262 let inner = self.lock();
263 match &inner.state {
264 VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
265 _ => {
266 error!(ENOTCONN)
267 }
268 }
269 }
270}
271
272impl VsockSocket {
273 pub fn remote_connection<L>(
274 &self,
275 locked: &mut Locked<L>,
276 socket: &Socket,
277 current_task: &CurrentTask,
278 file: FileHandle,
279 ) -> Result<(), Errno>
280 where
281 L: LockEqualOrBefore<FileOpsCore>,
282 {
283 assert!(file.flags().contains(OpenFlags::NONBLOCK));
286 if socket.socket_type != SocketType::Stream {
287 return error!(ENOTSUP);
288 }
289 if socket.domain != SocketDomain::Vsock {
290 return error!(EINVAL);
291 }
292
293 let mut inner = self.lock();
294 match &mut inner.state {
295 VsockSocketState::Listening(queue) => {
296 if queue.sockets.len() >= queue.backlog {
297 return error!(EAGAIN);
298 }
299 let remote_socket = Socket::new(
300 locked,
301 current_task,
302 SocketDomain::Vsock,
303 SocketType::Stream,
304 SocketProtocol::default(),
305 true,
306 )?;
307 downcast_socket_to_vsock(&remote_socket).lock().state =
308 VsockSocketState::Connected {
309 file,
310 peer_addr: SocketAddress::Vsock {
311 port: u32::MAX,
312 cid: starnix_uapi::VMADDR_CID_HOST,
313 },
314 };
315 queue.sockets.push_back(remote_socket);
316 inner.waiters.notify_fd_events(FdEvents::POLLIN);
317 Ok(())
318 }
319 _ => error!(EINVAL),
320 }
321 }
322}
323
324impl VsockSocketInner {
325 fn query_events<L>(
326 &self,
327 locked: &mut Locked<L>,
328 current_task: &CurrentTask,
329 ) -> Result<FdEvents, Errno>
330 where
331 L: LockEqualOrBefore<FileOpsCore>,
332 {
333 Ok(match &self.state {
334 VsockSocketState::Disconnected => FdEvents::empty(),
335 VsockSocketState::Connected { file, .. } => current_task
336 .override_creds(Credentials::root(), || file.query_events(locked, current_task))?,
337 VsockSocketState::Listening(queue) => {
338 if !queue.sockets.is_empty() {
339 FdEvents::POLLIN
340 } else {
341 FdEvents::empty()
342 }
343 }
344 VsockSocketState::Closed => FdEvents::POLLHUP,
345 })
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;

    /// Binds and listens on a vsock port, delivers a remote connection, accepts
    /// it, and checks that bytes flow in both directions.
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (fs1, fs2) = fidl::Socket::create_stream();
            const VSOCK_PORT: u32 = 5555;

            let listen_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let vsock_ns = current_task.live().abstract_vsock_namespace.clone();
            vsock_ns
                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
                .expect("Failed to bind socket.");
            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            let listen_socket =
                vsock_ns.lookup(&VSOCK_PORT).expect("Failed to look up listening socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            listen_socket
                .downcast_socket::<VsockSocket>()
                .unwrap()
                .remote_connection(locked, &listen_socket, &current_task, remote)
                .unwrap();

            let server_socket = listen_socket.accept(locked, &current_task).unwrap();

            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(buffer_iterator.data(), test_bytes_in);

            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listen_socket.close(locked, &current_task);
        })
        .await;
    }

    /// Checks that a write on a connected vsock socket completes while another
    /// thread is reading from the same socket.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
                file: remote,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            const XFER_SIZE: usize = 42;

            let socket_clone = socket_file.clone();
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = socket_clone
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            // Give the spawned reader a chance to start its read before the
            // write below is issued.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            let mut buffer = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
            block_on(result).unwrap();
        })
        .await;
    }

    /// Checks that `query_events` and epoll track data availability on the
    /// backing file of a connected vsock socket.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                false,
            )
            .expect("Failed to create socket.");
            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
                file: pipe,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let event = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, event)
                .expect("poll_file.add");

            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());

            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(fds.len(), 1);

            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());
        })
        .await;
    }
}