Skip to main content

starnix_core/vfs/socket/
socket_vsock.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::task::{CurrentTask, EventHandler, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9    AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10    SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::auth::Credentials;
14use starnix_uapi::errors::Errno;
15use starnix_uapi::open_flags::OpenFlags;
16use starnix_uapi::vfs::FdEvents;
17use starnix_uapi::{errno, error, ucred};
18
19// An implementation of AF_VSOCK.
20// See https://man7.org/linux/man-pages/man7/vsock.7.html
21
22pub struct VsockSocket {
23    inner: Mutex<VsockSocketInner>,
24}
25
26struct VsockSocketInner {
27    /// The address that this socket has been bound to, if it has been bound.
28    address: Option<SocketAddress>,
29
30    // WaitQueue for listening sockets.
31    waiters: WaitQueue,
32
33    // state of the vsock. Contains a handle to a ZxioBackedSocket when connected.
34    state: VsockSocketState,
35}
36
37enum VsockSocketState {
38    /// The socket has not been connected.
39    Disconnected,
40
41    /// The socket has had `listen` called and can accept incoming connections.
42    Listening(AcceptQueue),
43
44    /// The socket is connected to a ZxioBackedSocket.
45    Connected { file: FileHandle, peer_addr: SocketAddress },
46
47    /// The socket is closed.
48    Closed,
49}
50
51fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
52    // It is a programing error if we are downcasting
53    // a different type of socket as sockets from different families
54    // should not communicate, so unwrapping here
55    // will let us know that.
56    socket.downcast_socket::<VsockSocket>().unwrap()
57}
58
59impl VsockSocket {
60    pub fn new(_socket_type: SocketType) -> VsockSocket {
61        VsockSocket {
62            inner: Mutex::new(VsockSocketInner {
63                address: None,
64                waiters: WaitQueue::default(),
65                state: VsockSocketState::Disconnected,
66            }),
67        }
68    }
69
70    /// Locks and returns the inner state of the Socket.
71    fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
72        self.inner.lock()
73    }
74}
75
76impl SocketOps for VsockSocket {
77    // Connect with Vsock sockets is not allowed as
78    // we only connect from the enclosing OK.
79    fn connect(
80        &self,
81        _locked: &mut Locked<FileOpsCore>,
82        _socket: &SocketHandle,
83        _current_task: &CurrentTask,
84        _peer: SocketPeer,
85    ) -> Result<(), Errno> {
86        error!(EPROTOTYPE)
87    }
88
89    fn listen(
90        &self,
91        _locked: &mut Locked<FileOpsCore>,
92        _socket: &Socket,
93        backlog: i32,
94        _credentials: ucred,
95    ) -> Result<(), Errno> {
96        let mut inner = self.lock();
97        let is_bound = inner.address.is_some();
98        let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
99        match &mut inner.state {
100            VsockSocketState::Disconnected if is_bound => {
101                inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
102                Ok(())
103            }
104            VsockSocketState::Listening(queue) => {
105                queue.set_backlog(backlog)?;
106                Ok(())
107            }
108            _ => error!(EINVAL),
109        }
110    }
111
112    fn accept(
113        &self,
114        _locked: &mut Locked<FileOpsCore>,
115        socket: &Socket,
116        _current_task: &CurrentTask,
117    ) -> Result<SocketHandle, Errno> {
118        match socket.socket_type {
119            SocketType::Stream | SocketType::SeqPacket => {}
120            _ => return error!(EOPNOTSUPP),
121        }
122        let mut inner = self.lock();
123        let queue = match &mut inner.state {
124            VsockSocketState::Listening(queue) => queue,
125            _ => return error!(EINVAL),
126        };
127        let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
128        Ok(socket)
129    }
130
131    fn bind(
132        &self,
133        _locked: &mut Locked<FileOpsCore>,
134        _socket: &Socket,
135        _current_task: &CurrentTask,
136        socket_address: SocketAddress,
137    ) -> Result<(), Errno> {
138        match socket_address {
139            SocketAddress::Vsock { .. } => {}
140            _ => return error!(EINVAL),
141        }
142        let mut inner = self.lock();
143        if inner.address.is_some() {
144            return error!(EINVAL);
145        }
146        inner.address = Some(socket_address);
147        Ok(())
148    }
149
150    fn read(
151        &self,
152        locked: &mut Locked<FileOpsCore>,
153        _socket: &Socket,
154        current_task: &CurrentTask,
155        data: &mut dyn OutputBuffer,
156        _flags: SocketMessageFlags,
157    ) -> Result<MessageReadInfo, Errno> {
158        let (address, file) = {
159            let inner = self.lock();
160            let address = inner.address.clone();
161
162            match &inner.state {
163                VsockSocketState::Connected { file, .. } => (address, file.clone()),
164                _ => return error!(EBADF),
165            }
166        };
167        let bytes_read = current_task
168            .override_creds(Credentials::root(), || file.read(locked, current_task, data))?;
169        Ok(MessageReadInfo {
170            bytes_read,
171            message_length: bytes_read,
172            address,
173            ancillary_data: vec![],
174        })
175    }
176
177    fn write(
178        &self,
179        locked: &mut Locked<FileOpsCore>,
180        _socket: &Socket,
181        current_task: &CurrentTask,
182        data: &mut dyn InputBuffer,
183        _dest_address: &mut Option<SocketAddress>,
184        _ancillary_data: &mut Vec<AncillaryData>,
185    ) -> Result<usize, Errno> {
186        let file = {
187            let inner = self.lock();
188            match &inner.state {
189                VsockSocketState::Connected { file, .. } => file.clone(),
190                _ => return error!(EBADF),
191            }
192        };
193        current_task.override_creds(Credentials::root(), || file.write(locked, current_task, data))
194    }
195
196    fn wait_async(
197        &self,
198        locked: &mut Locked<FileOpsCore>,
199        _socket: &Socket,
200        current_task: &CurrentTask,
201        waiter: &Waiter,
202        events: FdEvents,
203        handler: EventHandler,
204    ) -> WaitCanceler {
205        let inner = self.lock();
206        match &inner.state {
207            VsockSocketState::Connected { file, .. } => file
208                .wait_async(locked, current_task, waiter, events, handler)
209                .expect("vsock socket should be connected to a file that can be waited on"),
210            _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
211        }
212    }
213
214    fn query_events(
215        &self,
216        locked: &mut Locked<FileOpsCore>,
217        _socket: &Socket,
218        current_task: &CurrentTask,
219    ) -> Result<FdEvents, Errno> {
220        self.lock().query_events(locked, current_task)
221    }
222
223    fn shutdown(
224        &self,
225        _locked: &mut Locked<FileOpsCore>,
226        _socket: &Socket,
227        _how: SocketShutdownFlags,
228    ) -> Result<(), Errno> {
229        self.lock().state = VsockSocketState::Closed;
230        Ok(())
231    }
232
233    fn close(
234        &self,
235        locked: &mut Locked<FileOpsCore>,
236        _current_task: &CurrentTask,
237        socket: &Socket,
238    ) {
239        // Call to shutdown should never fail, so unwrap is OK
240        self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
241            .unwrap();
242    }
243
244    fn getsockname(
245        &self,
246        _locked: &mut Locked<FileOpsCore>,
247        socket: &Socket,
248    ) -> Result<SocketAddress, Errno> {
249        let inner = self.lock();
250        if let Some(address) = &inner.address {
251            Ok(address.clone())
252        } else {
253            Ok(SocketAddress::default_for_domain(socket.domain))
254        }
255    }
256
257    fn getpeername(
258        &self,
259        _locked: &mut Locked<FileOpsCore>,
260        _socket: &Socket,
261    ) -> Result<SocketAddress, Errno> {
262        let inner = self.lock();
263        match &inner.state {
264            VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
265            _ => {
266                error!(ENOTCONN)
267            }
268        }
269    }
270}
271
272impl VsockSocket {
273    pub fn remote_connection<L>(
274        &self,
275        locked: &mut Locked<L>,
276        socket: &Socket,
277        current_task: &CurrentTask,
278        file: FileHandle,
279    ) -> Result<(), Errno>
280    where
281        L: LockEqualOrBefore<FileOpsCore>,
282    {
283        // we only allow non-blocking files here, so that
284        // read and write on file can return EAGAIN.
285        assert!(file.flags().contains(OpenFlags::NONBLOCK));
286        if socket.socket_type != SocketType::Stream {
287            return error!(ENOTSUP);
288        }
289        if socket.domain != SocketDomain::Vsock {
290            return error!(EINVAL);
291        }
292
293        let mut inner = self.lock();
294        match &mut inner.state {
295            VsockSocketState::Listening(queue) => {
296                if queue.sockets.len() >= queue.backlog {
297                    return error!(EAGAIN);
298                }
299                let remote_socket = Socket::new(
300                    locked,
301                    current_task,
302                    SocketDomain::Vsock,
303                    SocketType::Stream,
304                    SocketProtocol::default(),
305                    /* kernel_private = */ true,
306                )?;
307                downcast_socket_to_vsock(&remote_socket).lock().state =
308                    VsockSocketState::Connected {
309                        file,
310                        peer_addr: SocketAddress::Vsock {
311                            port: u32::MAX,
312                            cid: starnix_uapi::VMADDR_CID_HOST,
313                        },
314                    };
315                queue.sockets.push_back(remote_socket);
316                inner.waiters.notify_fd_events(FdEvents::POLLIN);
317                Ok(())
318            }
319            _ => error!(EINVAL),
320        }
321    }
322}
323
324impl VsockSocketInner {
325    fn query_events<L>(
326        &self,
327        locked: &mut Locked<L>,
328        current_task: &CurrentTask,
329    ) -> Result<FdEvents, Errno>
330    where
331        L: LockEqualOrBefore<FileOpsCore>,
332    {
333        Ok(match &self.state {
334            VsockSocketState::Disconnected => FdEvents::empty(),
335            VsockSocketState::Connected { file, .. } => current_task
336                .override_creds(Credentials::root(), || file.query_events(locked, current_task))?,
337            VsockSocketState::Listening(queue) => {
338                if !queue.sockets.is_empty() {
339                    FdEvents::POLLIN
340                } else {
341                    FdEvents::empty()
342                }
343            }
344            VsockSocketState::Closed => FdEvents::POLLHUP,
345        })
346    }
347}
348
349#[cfg(test)]
350mod tests {
351    use super::*;
352    use crate::fs::fuchsia::create_fuchsia_pipe;
353    use crate::mm::PAGE_SIZE;
354    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
355    use crate::testing::spawn_kernel_and_run;
356    use crate::vfs::EpollFileObject;
357    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
358    use crate::vfs::socket::SocketFile;
359    use futures::executor::block_on;
360    use starnix_sync::Unlocked;
361    use starnix_uapi::vfs::EpollEvent;
362    use syncio::Zxio;
363    use zx::HandleBased;
364
365    #[::fuchsia::test]
366    async fn test_vsock_socket() {
367        spawn_kernel_and_run(async |locked, current_task| {
368            let (fs1, fs2) = fidl::Socket::create_stream();
369            const VSOCK_PORT: u32 = 5555;
370
371            let listen_socket = Socket::new(
372                locked,
373                &current_task,
374                SocketDomain::Vsock,
375                SocketType::Stream,
376                SocketProtocol::default(),
377                /* kernel_private = */ false,
378            )
379            .expect("Failed to create socket.");
380            current_task
381                .abstract_vsock_namespace
382                .bind(locked, &current_task, VSOCK_PORT, &listen_socket)
383                .expect("Failed to bind socket.");
384            listen_socket.listen(locked, &current_task, 10).expect("Failed to listen.");
385
386            let listen_socket = current_task
387                .abstract_vsock_namespace
388                .lookup(&VSOCK_PORT)
389                .expect("Failed to look up listening socket.");
390            let remote = create_fuchsia_pipe(
391                locked,
392                &current_task,
393                fs2,
394                OpenFlags::RDWR | OpenFlags::NONBLOCK,
395            )
396            .unwrap();
397            listen_socket
398                .downcast_socket::<VsockSocket>()
399                .unwrap()
400                .remote_connection(locked, &listen_socket, &current_task, remote)
401                .unwrap();
402
403            let server_socket = listen_socket.accept(locked, &current_task).unwrap();
404
405            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
406            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
407            let mut buffer_iterator = VecOutputBuffer::new(*PAGE_SIZE as usize);
408            let read_message_info = server_socket
409                .read(locked, &current_task, &mut buffer_iterator, SocketMessageFlags::empty())
410                .unwrap();
411            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
412            assert_eq!(buffer_iterator.data(), test_bytes_in);
413
414            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
415            let mut buffer_iterator = VecInputBuffer::new(&test_bytes_out);
416            server_socket
417                .write(locked, &current_task, &mut buffer_iterator, &mut None, &mut vec![])
418                .unwrap();
419            assert_eq!(buffer_iterator.bytes_read(), test_bytes_out.len());
420
421            let mut read_back_buf = [0u8; 100];
422            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
423            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);
424
425            server_socket.close(locked, &current_task);
426            listen_socket.close(locked, &current_task);
427        })
428        .await;
429    }
430
431    #[::fuchsia::test]
432    async fn test_vsock_write_while_read() {
433        spawn_kernel_and_run(async |locked, current_task| {
434            let kernel = current_task.kernel();
435            let (fs1, fs2) = fidl::Socket::create_stream();
436            let socket = Socket::new(
437                locked,
438                &current_task,
439                SocketDomain::Vsock,
440                SocketType::Stream,
441                SocketProtocol::default(),
442                /* kernel_private = */ false,
443            )
444            .expect("Failed to create socket.");
445            let remote = create_fuchsia_pipe(
446                locked,
447                &current_task,
448                fs2,
449                OpenFlags::RDWR | OpenFlags::NONBLOCK,
450            )
451            .unwrap();
452            downcast_socket_to_vsock(&socket).lock().state = VsockSocketState::Connected {
453                file: remote,
454                peer_addr: SocketAddress::Vsock {
455                    port: u32::MAX,
456                    cid: starnix_uapi::VMADDR_CID_HOST,
457                },
458            };
459            let socket_file =
460                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
461                    .expect("Failed to create socket file.");
462
463            const XFER_SIZE: usize = 42;
464
465            let socket_clone = socket_file.clone();
466            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
467                let bytes_read = socket_clone
468                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
469                    .unwrap();
470                assert_eq!(XFER_SIZE, bytes_read);
471            };
472            let (result, req) =
473                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
474            kernel.kthreads.spawner().spawn_from_request(req);
475
476            // Wait for the thread to become blocked on the read.
477            std::thread::sleep(std::time::Duration::from_secs(2));
478
479            socket_file
480                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
481                .unwrap();
482
483            let mut buffer = [0u8; 1024];
484            assert_eq!(XFER_SIZE, fs1.read(&mut buffer).unwrap());
485            assert_eq!(XFER_SIZE, fs1.write(&buffer[..XFER_SIZE]).unwrap());
486            block_on(result).unwrap();
487        })
488        .await;
489    }
490
491    #[::fuchsia::test]
492    async fn test_vsock_poll() {
493        spawn_kernel_and_run(async |locked, current_task| {
494            let (client, server) = zx::Socket::create_stream();
495            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
496                .expect("create_fuchsia_pipe");
497            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
498            let socket_object = Socket::new(
499                locked,
500                &current_task,
501                SocketDomain::Vsock,
502                SocketType::Stream,
503                SocketProtocol::default(),
504                /* kernel_private = */ false,
505            )
506            .expect("Failed to create socket.");
507            downcast_socket_to_vsock(&socket_object).lock().state = VsockSocketState::Connected {
508                file: pipe,
509                peer_addr: SocketAddress::Vsock {
510                    port: u32::MAX,
511                    cid: starnix_uapi::VMADDR_CID_HOST,
512                },
513            };
514            let socket = SocketFile::from_socket(
515                locked,
516                &current_task,
517                socket_object.clone(),
518                OpenFlags::RDWR,
519                false,
520            )
521            .expect("Failed to create socket file.");
522
523            assert_eq!(
524                socket.query_events(locked, &current_task),
525                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
526            );
527
528            let epoll_object = EpollFileObject::new_file(locked, &current_task);
529            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
530            let event = EpollEvent::new(FdEvents::POLLIN, 0);
531            epoll_file
532                .add(locked, &current_task, &socket, &epoll_object, event)
533                .expect("poll_file.add");
534
535            let fds = epoll_file
536                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
537                .expect("wait");
538            assert!(fds.is_empty());
539
540            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);
541
542            assert_eq!(
543                socket.query_events(locked, &current_task),
544                Ok(FdEvents::POLLOUT
545                    | FdEvents::POLLWRNORM
546                    | FdEvents::POLLIN
547                    | FdEvents::POLLRDNORM)
548            );
549            let fds = epoll_file
550                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
551                .expect("wait");
552            assert_eq!(fds.len(), 1);
553
554            assert_eq!(
555                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
556                1
557            );
558
559            assert_eq!(
560                socket.query_events(locked, &current_task),
561                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
562            );
563            let fds = epoll_file
564                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
565                .expect("wait");
566            assert!(fds.is_empty());
567        })
568        .await;
569    }
570}