Skip to main content

starnix_core/vfs/socket/
socket_vsock.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::task::{CurrentTask, EventHandler, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9    AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10    SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::auth::Credentials;
14use starnix_uapi::errors::Errno;
15use starnix_uapi::open_flags::OpenFlags;
16use starnix_uapi::vfs::FdEvents;
17use starnix_uapi::{errno, error, ucred};
18
19// An implementation of AF_VSOCK.
20// See https://man7.org/linux/man-pages/man7/vsock.7.html
21
22pub struct VsockSocket {
23    inner: Mutex<VsockSocketInner>,
24}
25
26struct VsockSocketInner {
27    /// The address that this socket has been bound to, if it has been bound.
28    address: Option<SocketAddress>,
29
30    // WaitQueue for listening sockets.
31    waiters: WaitQueue,
32
33    // state of the vsock. Contains a handle to a ZxioBackedSocket when connected.
34    state: VsockSocketState,
35}
36
37enum VsockSocketState {
38    /// The socket has not been connected.
39    Disconnected,
40
41    /// The socket has had `listen` called and can accept incoming connections.
42    Listening(AcceptQueue),
43
44    /// The socket is connected to a ZxioBackedSocket.
45    Connected { file: FileHandle, peer_addr: SocketAddress },
46
47    /// The socket is closed.
48    Closed,
49}
50
51fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
52    // It is a programing error if we are downcasting
53    // a different type of socket as sockets from different families
54    // should not communicate, so unwrapping here
55    // will let us know that.
56    socket.downcast_socket::<VsockSocket>().unwrap()
57}
58
59impl VsockSocket {
60    pub fn new(_socket_type: SocketType) -> VsockSocket {
61        VsockSocket {
62            inner: Mutex::new(VsockSocketInner {
63                address: None,
64                waiters: WaitQueue::default(),
65                state: VsockSocketState::Disconnected,
66            }),
67        }
68    }
69
70    /// Locks and returns the inner state of the Socket.
71    fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
72        self.inner.lock()
73    }
74}
75
76impl SocketOps for VsockSocket {
77    // Connect with Vsock sockets is not allowed as
78    // we only connect from the enclosing OK.
79    fn connect(
80        &self,
81        _locked: &mut Locked<FileOpsCore>,
82        _socket: &SocketHandle,
83        _current_task: &CurrentTask,
84        _peer: SocketPeer,
85    ) -> Result<(), Errno> {
86        error!(EPROTOTYPE)
87    }
88
89    fn listen(
90        &self,
91        _locked: &mut Locked<FileOpsCore>,
92        _socket: &Socket,
93        backlog: i32,
94        _credentials: ucred,
95    ) -> Result<(), Errno> {
96        let mut inner = self.lock();
97        let is_bound = inner.address.is_some();
98        let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
99        match &mut inner.state {
100            VsockSocketState::Disconnected if is_bound => {
101                inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
102                Ok(())
103            }
104            VsockSocketState::Listening(queue) => {
105                queue.set_backlog(backlog)?;
106                Ok(())
107            }
108            _ => error!(EINVAL),
109        }
110    }
111
112    fn accept(
113        &self,
114        _locked: &mut Locked<FileOpsCore>,
115        socket: &Socket,
116        _current_task: &CurrentTask,
117    ) -> Result<SocketHandle, Errno> {
118        match socket.socket_type {
119            SocketType::Stream | SocketType::SeqPacket => {}
120            _ => return error!(EOPNOTSUPP),
121        }
122        let mut inner = self.lock();
123        let queue = match &mut inner.state {
124            VsockSocketState::Listening(queue) => queue,
125            _ => return error!(EINVAL),
126        };
127        let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
128        Ok(socket)
129    }
130
131    fn bind(
132        &self,
133        _locked: &mut Locked<FileOpsCore>,
134        _socket: &Socket,
135        _current_task: &CurrentTask,
136        socket_address: SocketAddress,
137    ) -> Result<(), Errno> {
138        match socket_address {
139            SocketAddress::Vsock { .. } => {}
140            _ => return error!(EINVAL),
141        }
142        let mut inner = self.lock();
143        if inner.address.is_some() {
144            return error!(EINVAL);
145        }
146        inner.address = Some(socket_address);
147        Ok(())
148    }
149
150    fn read(
151        &self,
152        locked: &mut Locked<FileOpsCore>,
153        _socket: &Socket,
154        current_task: &CurrentTask,
155        data: &mut dyn OutputBuffer,
156        _flags: SocketMessageFlags,
157    ) -> Result<MessageReadInfo, Errno> {
158        let (address, file) = {
159            let inner = self.lock();
160            let address = inner.address.clone();
161
162            match &inner.state {
163                VsockSocketState::Connected { file, .. } => (address, file.clone()),
164                _ => return error!(EBADF),
165            }
166        };
167        let bytes_read = current_task
168            .override_creds(Credentials::root(), || file.read(locked, current_task, data))?;
169        Ok(MessageReadInfo {
170            bytes_read,
171            message_length: bytes_read,
172            address,
173            ancillary_data: vec![],
174        })
175    }
176
177    fn write(
178        &self,
179        locked: &mut Locked<FileOpsCore>,
180        _socket: &Socket,
181        current_task: &CurrentTask,
182        data: &mut dyn InputBuffer,
183        _dest_address: &mut Option<SocketAddress>,
184        _ancillary_data: &mut Vec<AncillaryData>,
185    ) -> Result<usize, Errno> {
186        let file = {
187            let inner = self.lock();
188            match &inner.state {
189                VsockSocketState::Connected { file, .. } => file.clone(),
190                _ => return error!(EBADF),
191            }
192        };
193        current_task.override_creds(Credentials::root(), || file.write(locked, current_task, data))
194    }
195
196    fn wait_async(
197        &self,
198        locked: &mut Locked<FileOpsCore>,
199        _socket: &Socket,
200        current_task: &CurrentTask,
201        waiter: &Waiter,
202        events: FdEvents,
203        handler: EventHandler,
204    ) -> WaitCanceler {
205        let inner = self.lock();
206        match &inner.state {
207            VsockSocketState::Connected { file, .. } => file
208                .wait_async(locked, current_task, waiter, events, handler)
209                .expect("vsock socket should be connected to a file that can be waited on"),
210            _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
211        }
212    }
213
214    fn query_events(
215        &self,
216        locked: &mut Locked<FileOpsCore>,
217        _socket: &Socket,
218        current_task: &CurrentTask,
219    ) -> Result<FdEvents, Errno> {
220        self.lock().query_events(locked, current_task)
221    }
222
223    fn shutdown(
224        &self,
225        _locked: &mut Locked<FileOpsCore>,
226        _socket: &Socket,
227        _how: SocketShutdownFlags,
228    ) -> Result<(), Errno> {
229        self.lock().state = VsockSocketState::Closed;
230        Ok(())
231    }
232
233    fn close(
234        &self,
235        locked: &mut Locked<FileOpsCore>,
236        _current_task: &CurrentTask,
237        socket: &Socket,
238    ) {
239        // Call to shutdown should never fail, so unwrap is OK
240        self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
241            .unwrap();
242    }
243
244    fn getsockname(
245        &self,
246        _locked: &mut Locked<FileOpsCore>,
247        socket: &Socket,
248    ) -> Result<SocketAddress, Errno> {
249        let inner = self.lock();
250        if let Some(address) = &inner.address {
251            Ok(address.clone())
252        } else {
253            Ok(SocketAddress::default_for_domain(socket.domain))
254        }
255    }
256
257    fn getpeername(
258        &self,
259        _locked: &mut Locked<FileOpsCore>,
260        _socket: &Socket,
261    ) -> Result<SocketAddress, Errno> {
262        let inner = self.lock();
263        match &inner.state {
264            VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
265            _ => {
266                error!(ENOTCONN)
267            }
268        }
269    }
270}
271
272impl VsockSocket {
273    pub fn remote_connection<L>(
274        &self,
275        locked: &mut Locked<L>,
276        socket: &Socket,
277        current_task: &CurrentTask,
278        file: FileHandle,
279    ) -> Result<(), Errno>
280    where
281        L: LockEqualOrBefore<FileOpsCore>,
282    {
283        // we only allow non-blocking files here, so that
284        // read and write on file can return EAGAIN.
285        assert!(file.flags().contains(OpenFlags::NONBLOCK));
286        if socket.socket_type != SocketType::Stream {
287            return error!(ENOTSUP);
288        }
289        if socket.domain != SocketDomain::Vsock {
290            return error!(EINVAL);
291        }
292
293        let mut inner = self.lock();
294        match &mut inner.state {
295            VsockSocketState::Listening(queue) => {
296                if queue.sockets.len() >= queue.backlog {
297                    return error!(EAGAIN);
298                }
299                let remote_socket = Socket::new(
300                    locked,
301                    current_task,
302                    SocketDomain::Vsock,
303                    SocketType::Stream,
304                    SocketProtocol::default(),
305                    /* kernel_private = */ true,
306                )?;
307                downcast_socket_to_vsock(&remote_socket).lock().state =
308                    VsockSocketState::Connected {
309                        file,
310                        peer_addr: SocketAddress::Vsock {
311                            port: u32::MAX,
312                            cid: starnix_uapi::VMADDR_CID_HOST,
313                        },
314                    };
315                queue.sockets.push_back(remote_socket);
316                inner.waiters.notify_fd_events(FdEvents::POLLIN);
317                Ok(())
318            }
319            _ => error!(EINVAL),
320        }
321    }
322}
323
324impl VsockSocketInner {
325    fn query_events<L>(
326        &self,
327        locked: &mut Locked<L>,
328        current_task: &CurrentTask,
329    ) -> Result<FdEvents, Errno>
330    where
331        L: LockEqualOrBefore<FileOpsCore>,
332    {
333        Ok(match &self.state {
334            VsockSocketState::Disconnected => FdEvents::empty(),
335            VsockSocketState::Connected { file, .. } => current_task
336                .override_creds(Credentials::root(), || file.query_events(locked, current_task))?,
337            VsockSocketState::Listening(queue) => {
338                if !queue.sockets.is_empty() {
339                    FdEvents::POLLIN
340                } else {
341                    FdEvents::empty()
342                }
343            }
344            VsockSocketState::Closed => FdEvents::POLLHUP,
345        })
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;
    use zx::HandleBased;

    // End-to-end flow: bind and listen in the abstract vsock namespace, hand the listener a
    // remote connection wrapping one end of a zircon socket, accept it, then move bytes in
    // both directions between the accepted socket and the other end (`fs1`).
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (fs1, fs2) = fidl::Socket::create_stream();
            const VSOCK_PORT: u32 = 5555;

            let listen_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            let vsock_ns = current_task.live().abstract_vsock_namespace.clone();
            vsock_ns
                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
                .expect("Failed to bind socket.");
            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            // Re-fetch the listener through the namespace, as an external connector would.
            let listen_socket =
                vsock_ns.lookup(&VSOCK_PORT).expect("Failed to look up listening socket.");
            // `remote_connection` asserts that the backing file is non-blocking.
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            listen_socket
                .downcast_socket::<VsockSocket>()
                .unwrap()
                .remote_connection(locked, &listen_socket, &current_task, remote)
                .unwrap();

            let server_socket = listen_socket.accept(locked, &current_task).unwrap();

            // fs1 -> accepted socket.
            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(buffer_iterator.data(), test_bytes_in);

            // Accepted socket -> fs1.
            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listen_socket.close(locked, &current_task);
        })
        .await;
    }

    // A read pending on a connected vsock socket in one kernel thread must not prevent a
    // write on the same socket from another thread: the spawned read only completes after
    // the write below has gone out and `fs1` echoes the bytes back.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            let remote = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            // Put the socket straight into the connected state, bypassing listen/accept.
            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
                file: remote,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            const XFER_SIZE: usize = 42;

            let socket_clone = socket_file.clone();
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = socket_clone
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            // Wait for the thread to become blocked on the read.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            // Echo the written bytes back so the pending read can complete.
            let mut buffer = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
            block_on(result).unwrap();
        })
        .await;
    }

    // `query_events` and epoll readiness on a connected socket track data arriving on, and
    // being drained from, the underlying zircon socket.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            // Put the socket straight into the connected state, bypassing listen/accept.
            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
                file: pipe,
                peer_addr: SocketAddress::Vsock {
                    port: u32::MAX,
                    cid: starnix_uapi::VMADDR_CID_HOST,
                },
            };
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            // Nothing to read yet: only writable.
            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let event = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, event)
                .expect("poll_file.add");

            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());

            // One byte from the peer makes the socket readable.
            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(fds.len(), 1);

            // Draining the byte clears readability again.
            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let fds = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(fds.is_empty());
        })
        .await;
    }
}