starnix_core/vfs/socket/
socket_vsock.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::task::{CurrentTask, EventHandler, FullCredentials, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9    AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10    SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::errors::Errno;
14use starnix_uapi::open_flags::OpenFlags;
15use starnix_uapi::vfs::FdEvents;
16use starnix_uapi::{errno, error, ucred};
17
18// An implementation of AF_VSOCK.
19// See https://man7.org/linux/man-pages/man7/vsock.7.html
20
/// Socket operations for an `AF_VSOCK` socket.
///
/// All mutable state is kept in a `VsockSocketInner` guarded by a single mutex.
pub struct VsockSocket {
    /// Bound address, wait queue and connection state of this socket.
    inner: Mutex<VsockSocketInner>,
}
24
/// Mutex-protected state of a `VsockSocket`.
struct VsockSocketInner {
    /// The address that this socket has been bound to, if it has been bound.
    address: Option<SocketAddress>,

    /// Wait queue used while the socket is not connected. It is notified with `POLLIN` when a
    /// new connection is queued on a listening socket.
    waiters: WaitQueue,

    /// Current state of the vsock. Holds the `FileHandle` used for I/O when connected
    /// (presumably a ZxioBackedSocket -- confirm against callers of `remote_connection`).
    state: VsockSocketState,
}
35
/// The states a vsock socket can be in.
enum VsockSocketState {
    /// The socket has not been connected. This is the state of a freshly created socket.
    Disconnected,

    /// The socket has had `listen` called and can accept incoming connections.
    /// The queue holds connections that have been established but not yet accepted.
    Listening(AcceptQueue),

    /// The socket is connected. Reads, writes and event queries are forwarded to `file`;
    /// `peer_addr` is the address reported by `getpeername`.
    Connected { file: FileHandle, peer_addr: SocketAddress },

    /// The socket is closed. Entered through `shutdown` (and therefore `close`).
    Closed,
}
49
50fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
51    // It is a programing error if we are downcasting
52    // a different type of socket as sockets from different families
53    // should not communicate, so unwrapping here
54    // will let us know that.
55    socket.downcast_socket::<VsockSocket>().unwrap()
56}
57
58impl VsockSocket {
59    pub fn new(_socket_type: SocketType) -> VsockSocket {
60        VsockSocket {
61            inner: Mutex::new(VsockSocketInner {
62                address: None,
63                waiters: WaitQueue::default(),
64                state: VsockSocketState::Disconnected,
65            }),
66        }
67    }
68
69    /// Locks and returns the inner state of the Socket.
70    fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
71        self.inner.lock()
72    }
73}
74
75impl SocketOps for VsockSocket {
76    // Connect with Vsock sockets is not allowed as
77    // we only connect from the enclosing OK.
78    fn connect(
79        &self,
80        _locked: &mut Locked<FileOpsCore>,
81        _socket: &SocketHandle,
82        _current_task: &CurrentTask,
83        _peer: SocketPeer,
84    ) -> Result<(), Errno> {
85        error!(EPROTOTYPE)
86    }
87
88    fn listen(
89        &self,
90        _locked: &mut Locked<FileOpsCore>,
91        _socket: &Socket,
92        backlog: i32,
93        _credentials: ucred,
94    ) -> Result<(), Errno> {
95        let mut inner = self.lock();
96        let is_bound = inner.address.is_some();
97        let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
98        match &mut inner.state {
99            VsockSocketState::Disconnected if is_bound => {
100                inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
101                Ok(())
102            }
103            VsockSocketState::Listening(queue) => {
104                queue.set_backlog(backlog)?;
105                Ok(())
106            }
107            _ => error!(EINVAL),
108        }
109    }
110
111    fn accept(
112        &self,
113        _locked: &mut Locked<FileOpsCore>,
114        socket: &Socket,
115        _current_task: &CurrentTask,
116    ) -> Result<SocketHandle, Errno> {
117        match socket.socket_type {
118            SocketType::Stream | SocketType::SeqPacket => {}
119            _ => return error!(EOPNOTSUPP),
120        }
121        let mut inner = self.lock();
122        let queue = match &mut inner.state {
123            VsockSocketState::Listening(queue) => queue,
124            _ => return error!(EINVAL),
125        };
126        let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
127        Ok(socket)
128    }
129
130    fn bind(
131        &self,
132        _locked: &mut Locked<FileOpsCore>,
133        _socket: &Socket,
134        _current_task: &CurrentTask,
135        socket_address: SocketAddress,
136    ) -> Result<(), Errno> {
137        match socket_address {
138            SocketAddress::Vsock { .. } => {}
139            _ => return error!(EINVAL),
140        }
141        let mut inner = self.lock();
142        if inner.address.is_some() {
143            return error!(EINVAL);
144        }
145        inner.address = Some(socket_address);
146        Ok(())
147    }
148
149    fn read(
150        &self,
151        locked: &mut Locked<FileOpsCore>,
152        _socket: &Socket,
153        current_task: &CurrentTask,
154        data: &mut dyn OutputBuffer,
155        _flags: SocketMessageFlags,
156    ) -> Result<MessageReadInfo, Errno> {
157        let (address, file) = {
158            let inner = self.lock();
159            let address = inner.address.clone();
160
161            match &inner.state {
162                VsockSocketState::Connected { file, .. } => (address, file.clone()),
163                _ => return error!(EBADF),
164            }
165        };
166        let bytes_read = current_task.override_creds(
167            |creds| *creds = FullCredentials::for_kernel(),
168            || file.read(locked, current_task, data),
169        )?;
170        Ok(MessageReadInfo {
171            bytes_read,
172            message_length: bytes_read,
173            address,
174            ancillary_data: vec![],
175        })
176    }
177
178    fn write(
179        &self,
180        locked: &mut Locked<FileOpsCore>,
181        _socket: &Socket,
182        current_task: &CurrentTask,
183        data: &mut dyn InputBuffer,
184        _dest_address: &mut Option<SocketAddress>,
185        _ancillary_data: &mut Vec<AncillaryData>,
186    ) -> Result<usize, Errno> {
187        let file = {
188            let inner = self.lock();
189            match &inner.state {
190                VsockSocketState::Connected { file, .. } => file.clone(),
191                _ => return error!(EBADF),
192            }
193        };
194        current_task.override_creds(
195            |creds| *creds = FullCredentials::for_kernel(),
196            || file.write(locked, current_task, data),
197        )
198    }
199
200    fn wait_async(
201        &self,
202        locked: &mut Locked<FileOpsCore>,
203        _socket: &Socket,
204        current_task: &CurrentTask,
205        waiter: &Waiter,
206        events: FdEvents,
207        handler: EventHandler,
208    ) -> WaitCanceler {
209        let inner = self.lock();
210        match &inner.state {
211            VsockSocketState::Connected { file, .. } => file
212                .wait_async(locked, current_task, waiter, events, handler)
213                .expect("vsock socket should be connected to a file that can be waited on"),
214            _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
215        }
216    }
217
218    fn query_events(
219        &self,
220        locked: &mut Locked<FileOpsCore>,
221        _socket: &Socket,
222        current_task: &CurrentTask,
223    ) -> Result<FdEvents, Errno> {
224        self.lock().query_events(locked, current_task)
225    }
226
227    fn shutdown(
228        &self,
229        _locked: &mut Locked<FileOpsCore>,
230        _socket: &Socket,
231        _how: SocketShutdownFlags,
232    ) -> Result<(), Errno> {
233        self.lock().state = VsockSocketState::Closed;
234        Ok(())
235    }
236
237    fn close(
238        &self,
239        locked: &mut Locked<FileOpsCore>,
240        _current_task: &CurrentTask,
241        socket: &Socket,
242    ) {
243        // Call to shutdown should never fail, so unwrap is OK
244        self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
245            .unwrap();
246    }
247
248    fn getsockname(
249        &self,
250        _locked: &mut Locked<FileOpsCore>,
251        socket: &Socket,
252    ) -> Result<SocketAddress, Errno> {
253        let inner = self.lock();
254        if let Some(address) = &inner.address {
255            Ok(address.clone())
256        } else {
257            Ok(SocketAddress::default_for_domain(socket.domain))
258        }
259    }
260
261    fn getpeername(
262        &self,
263        _locked: &mut Locked<FileOpsCore>,
264        _socket: &Socket,
265    ) -> Result<SocketAddress, Errno> {
266        let inner = self.lock();
267        match &inner.state {
268            VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
269            _ => {
270                error!(ENOTCONN)
271            }
272        }
273    }
274}
275
276impl VsockSocket {
277    pub fn remote_connection<L>(
278        &self,
279        locked: &mut Locked<L>,
280        socket: &Socket,
281        current_task: &CurrentTask,
282        file: FileHandle,
283    ) -> Result<(), Errno>
284    where
285        L: LockEqualOrBefore<FileOpsCore>,
286    {
287        // we only allow non-blocking files here, so that
288        // read and write on file can return EAGAIN.
289        assert!(file.flags().contains(OpenFlags::NONBLOCK));
290        if socket.socket_type != SocketType::Stream {
291            return error!(ENOTSUP);
292        }
293        if socket.domain != SocketDomain::Vsock {
294            return error!(EINVAL);
295        }
296
297        let mut inner = self.lock();
298        match &mut inner.state {
299            VsockSocketState::Listening(queue) => {
300                if queue.sockets.len() >= queue.backlog {
301                    return error!(EAGAIN);
302                }
303                let remote_socket = Socket::new(
304                    locked,
305                    current_task,
306                    SocketDomain::Vsock,
307                    SocketType::Stream,
308                    SocketProtocol::default(),
309                    /* kernel_private = */ false,
310                )?;
311                downcast_socket_to_vsock(&remote_socket).lock().state =
312                    VsockSocketState::Connected {
313                        file,
314                        peer_addr: SocketAddress::Vsock {
315                            port: u32::MAX,
316                            cid: starnix_uapi::VMADDR_CID_HOST,
317                        },
318                    };
319                queue.sockets.push_back(remote_socket);
320                inner.waiters.notify_fd_events(FdEvents::POLLIN);
321                Ok(())
322            }
323            _ => error!(EINVAL),
324        }
325    }
326}
327
328impl VsockSocketInner {
329    fn query_events<L>(
330        &self,
331        locked: &mut Locked<L>,
332        current_task: &CurrentTask,
333    ) -> Result<FdEvents, Errno>
334    where
335        L: LockEqualOrBefore<FileOpsCore>,
336    {
337        Ok(match &self.state {
338            VsockSocketState::Disconnected => FdEvents::empty(),
339            VsockSocketState::Connected { file, .. } => current_task.override_creds(
340                |creds| *creds = FullCredentials::for_kernel(),
341                || file.query_events(locked, current_task),
342            )?,
343            VsockSocketState::Listening(queue) => {
344                if !queue.sockets.is_empty() {
345                    FdEvents::POLLIN
346                } else {
347                    FdEvents::empty()
348                }
349            }
350            VsockSocketState::Closed => FdEvents::POLLHUP,
351        })
352    }
353}
354
355#[cfg(test)]
356mod tests {
357    use super::*;
358    use crate::fs::fuchsia::create_fuchsia_pipe;
359    use crate::mm::PAGE_SIZE;
360    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
361    use crate::testing::spawn_kernel_and_run;
362    use crate::vfs::EpollFileObject;
363    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
364    use crate::vfs::socket::SocketFile;
365    use futures::executor::block_on;
366    use starnix_sync::Unlocked;
367    use starnix_uapi::vfs::EpollEvent;
368    use syncio::Zxio;
369    use zx::HandleBased;
370
371    #[::fuchsia::test]
372    async fn test_vsock_socket() {
373        spawn_kernel_and_run(async |locked, current_task| {
374            let (fs1, fs2) = fidl::Socket::create_stream();
375            const VSOCK_PORT: u32 = 5555;
376
377            let listen_socket = Socket::new(
378                locked,
379                &current_task,
380                SocketDomain::Vsock,
381                SocketType::Stream,
382                SocketProtocol::default(),
383                /* kernel_private = */ false,
384            )
385            .expect("Failed to create socket.");
386            current_task
387                .abstract_vsock_namespace
388                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
389                .expect("Failed to bind socket.");
390            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");
391
392            let listen_socket = current_task
393                .abstract_vsock_namespace
394                .lookup(&VSOCK_PORT)
395                .expect("Failed to look up listening socket.");
396            let remote = create_fuchsia_pipe(
397                locked,
398                &current_task,
399                fs2,
400                OpenFlags::RDWR | OpenFlags::NONBLOCK,
401            )
402            .unwrap();
403            listen_socket
404                .downcast_socket::<VsockSocket>()
405                .unwrap()
406                .remote_connection(locked, &listen_socket, &current_task, remote)
407                .unwrap();
408
409            let server_socket = listen_socket.accept(locked, &current_task).unwrap();
410
411            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
412            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
413            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
414            let read_message_info = server_socket
415                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
416                .unwrap();
417            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
418            assert_eq!(buffer_iterator.data(), test_bytes_in);
419
420            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
421            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
422            server_socket
423                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
424                .unwrap();
425            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());
426
427            let mut read_back_buf = [0u8; 100];
428            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
429            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);
430
431            server_socket.close(locked, &current_task);
432            listen_socket.close(locked, &current_task);
433        })
434        .await;
435    }
436
437    #[::fuchsia::test]
438    async fn test_vsock_write_while_read() {
439        spawn_kernel_and_run(async |locked, current_task| {
440            let kernel = current_task.kernel();
441            let (fs1, fs2) = fidl::Socket::create_stream();
442            let socket = Socket::new(
443                locked,
444                &current_task,
445                SocketDomain::Vsock,
446                SocketType::Stream,
447                SocketProtocol::default(),
448                /* kernel_private = */ false,
449            )
450            .expect("Failed to create socket.");
451            let remote = create_fuchsia_pipe(
452                locked,
453                &current_task,
454                fs2,
455                OpenFlags::RDWR | OpenFlags::NONBLOCK,
456            )
457            .unwrap();
458            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
459                file: remote,
460                peer_addr: SocketAddress::Vsock {
461                    port: u32::MAX,
462                    cid: starnix_uapi::VMADDR_CID_HOST,
463                },
464            };
465            let socket_file =
466                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
467                    .expect("Failed to create socket file.");
468
469            const XFER_SIZE: usize = 42;
470
471            let socket_clone = socket_file.clone();
472            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
473                let bytes_read = socket_clone
474                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
475                    .unwrap();
476                assert_eq!(XFER_SIZE, bytes_read);
477            };
478            let (result, req) =
479                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
480            kernel.kthreads.spawner().spawn_from_request(req);
481
482            // Wait for the thread to become blocked on the read.
483            std::thread::sleep(std::time::Duration::from_secs(2));
484
485            socket_file
486                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
487                .unwrap();
488
489            let mut buffer = [0u8; 1024];
490            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
491            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
492            block_on(result).unwrap();
493        })
494        .await;
495    }
496
497    #[::fuchsia::test]
498    async fn test_vsock_poll() {
499        spawn_kernel_and_run(async |locked, current_task| {
500            let (client, server) = zx::Socket::create_stream();
501            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
502                .expect("create_fuchsia_pipe");
503            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
504            let socket_object = Socket::new(
505                locked,
506                &current_task,
507                SocketDomain::Vsock,
508                SocketType::Stream,
509                SocketProtocol::default(),
510                /* kernel_private = */ false,
511            )
512            .expect("Failed to create socket.");
513            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
514                file: pipe,
515                peer_addr: SocketAddress::Vsock {
516                    port: u32::MAX,
517                    cid: starnix_uapi::VMADDR_CID_HOST,
518                },
519            };
520            let socket = SocketFile::from_socket(
521                locked,
522                &current_task,
523                socket_object.clone(),
524                OpenFlags::RDWR,
525                false,
526            )
527            .expect("Failed to create socket file.");
528
529            assert_eq!(
530                socket.query_events(locked, &current_task),
531                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
532            );
533
534            let epoll_object = EpollFileObject::new_file(locked, &current_task);
535            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
536            let event = EpollEvent::new(FdEvents::POLLIN, 0);
537            epoll_file
538                .add(locked, &current_task, &socket, &epoll_object, event)
539                .expect("poll_file.add");
540
541            let fds = epoll_file
542                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
543                .expect("wait");
544            assert!(fds.is_empty());
545
546            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);
547
548            assert_eq!(
549                socket.query_events(locked, &current_task),
550                Ok(FdEvents::POLLOUT
551                    | FdEvents::POLLWRNORM
552                    | FdEvents::POLLIN
553                    | FdEvents::POLLRDNORM)
554            );
555            let fds = epoll_file
556                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
557                .expect("wait");
558            assert_eq!(fds.len(), 1);
559
560            assert_eq!(
561                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
562                1
563            );
564
565            assert_eq!(
566                socket.query_events(locked, &current_task),
567                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
568            );
569            let fds = epoll_file
570                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
571                .expect("wait");
572            assert!(fds.is_empty());
573        })
574        .await;
575    }
576}