starnix_core/vfs/socket/
socket_vsock.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::task::{CurrentTask, EventHandler, FullCredentials, WaitCanceler, WaitQueue, Waiter};
6use crate::vfs::FileHandle;
7use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
8use crate::vfs::socket::{
9    AcceptQueue, DEFAULT_LISTEN_BACKLOG, Socket, SocketAddress, SocketDomain, SocketHandle,
10    SocketMessageFlags, SocketOps, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType,
11};
12use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
13use starnix_uapi::errors::Errno;
14use starnix_uapi::open_flags::OpenFlags;
15use starnix_uapi::vfs::FdEvents;
16use starnix_uapi::{errno, error, ucred};
17
18// An implementation of AF_VSOCK.
19// See https://man7.org/linux/man-pages/man7/vsock.7.html
20
21pub struct VsockSocket {
22    inner: Mutex<VsockSocketInner>,
23}
24
25struct VsockSocketInner {
26    /// The address that this socket has been bound to, if it has been bound.
27    address: Option<SocketAddress>,
28
29    // WaitQueue for listening sockets.
30    waiters: WaitQueue,
31
32    // state of the vsock. Contains a handle to a ZxioBackedSocket when connected.
33    state: VsockSocketState,
34}
35
36enum VsockSocketState {
37    /// The socket has not been connected.
38    Disconnected,
39
40    /// The socket has had `listen` called and can accept incoming connections.
41    Listening(AcceptQueue),
42
43    /// The socket is connected to a ZxioBackedSocket.
44    Connected { file: FileHandle, peer_addr: SocketAddress },
45
46    /// The socket is closed.
47    Closed,
48}
49
50fn downcast_socket_to_vsock(socket: &Socket) -> &VsockSocket {
51    // It is a programing error if we are downcasting
52    // a different type of socket as sockets from different families
53    // should not communicate, so unwrapping here
54    // will let us know that.
55    socket.downcast_socket::<VsockSocket>().unwrap()
56}
57
58impl VsockSocket {
59    pub fn new(_socket_type: SocketType) -> VsockSocket {
60        VsockSocket {
61            inner: Mutex::new(VsockSocketInner {
62                address: None,
63                waiters: WaitQueue::default(),
64                state: VsockSocketState::Disconnected,
65            }),
66        }
67    }
68
69    /// Locks and returns the inner state of the Socket.
70    fn lock(&self) -> starnix_sync::MutexGuard<'_, VsockSocketInner> {
71        self.inner.lock()
72    }
73}
74
75impl SocketOps for VsockSocket {
76    // Connect with Vsock sockets is not allowed as
77    // we only connect from the enclosing OK.
78    fn connect(
79        &self,
80        _locked: &mut Locked<FileOpsCore>,
81        _socket: &SocketHandle,
82        _current_task: &CurrentTask,
83        _peer: SocketPeer,
84    ) -> Result<(), Errno> {
85        error!(EPROTOTYPE)
86    }
87
88    fn listen(
89        &self,
90        _locked: &mut Locked<FileOpsCore>,
91        _socket: &Socket,
92        backlog: i32,
93        _credentials: ucred,
94    ) -> Result<(), Errno> {
95        let mut inner = self.lock();
96        let is_bound = inner.address.is_some();
97        let backlog = if backlog < 0 { DEFAULT_LISTEN_BACKLOG } else { backlog as usize };
98        match &mut inner.state {
99            VsockSocketState::Disconnected if is_bound => {
100                inner.state = VsockSocketState::Listening(AcceptQueue::new(backlog));
101                Ok(())
102            }
103            VsockSocketState::Listening(queue) => {
104                queue.set_backlog(backlog)?;
105                Ok(())
106            }
107            _ => error!(EINVAL),
108        }
109    }
110
111    fn accept(
112        &self,
113        _locked: &mut Locked<FileOpsCore>,
114        socket: &Socket,
115        _current_task: &CurrentTask,
116    ) -> Result<SocketHandle, Errno> {
117        match socket.socket_type {
118            SocketType::Stream | SocketType::SeqPacket => {}
119            _ => return error!(EOPNOTSUPP),
120        }
121        let mut inner = self.lock();
122        let queue = match &mut inner.state {
123            VsockSocketState::Listening(queue) => queue,
124            _ => return error!(EINVAL),
125        };
126        let socket = queue.sockets.pop_front().ok_or_else(|| errno!(EAGAIN))?;
127        Ok(socket)
128    }
129
130    fn bind(
131        &self,
132        _locked: &mut Locked<FileOpsCore>,
133        _socket: &Socket,
134        _current_task: &CurrentTask,
135        socket_address: SocketAddress,
136    ) -> Result<(), Errno> {
137        match socket_address {
138            SocketAddress::Vsock { .. } => {}
139            _ => return error!(EINVAL),
140        }
141        let mut inner = self.lock();
142        if inner.address.is_some() {
143            return error!(EINVAL);
144        }
145        inner.address = Some(socket_address);
146        Ok(())
147    }
148
149    fn read(
150        &self,
151        locked: &mut Locked<FileOpsCore>,
152        _socket: &Socket,
153        current_task: &CurrentTask,
154        data: &mut dyn OutputBuffer,
155        _flags: SocketMessageFlags,
156    ) -> Result<MessageReadInfo, Errno> {
157        let (address, file) = {
158            let inner = self.lock();
159            let address = inner.address.clone();
160
161            match &inner.state {
162                VsockSocketState::Connected { file, .. } => (address, file.clone()),
163                _ => return error!(EBADF),
164            }
165        };
166        let bytes_read = current_task.override_creds(FullCredentials::for_kernel(), || {
167            file.read(locked, current_task, data)
168        })?;
169        Ok(MessageReadInfo {
170            bytes_read,
171            message_length: bytes_read,
172            address,
173            ancillary_data: vec![],
174        })
175    }
176
177    fn write(
178        &self,
179        locked: &mut Locked<FileOpsCore>,
180        _socket: &Socket,
181        current_task: &CurrentTask,
182        data: &mut dyn InputBuffer,
183        _dest_address: &mut Option<SocketAddress>,
184        _ancillary_data: &mut Vec<AncillaryData>,
185    ) -> Result<usize, Errno> {
186        let file = {
187            let inner = self.lock();
188            match &inner.state {
189                VsockSocketState::Connected { file, .. } => file.clone(),
190                _ => return error!(EBADF),
191            }
192        };
193        current_task.override_creds(FullCredentials::for_kernel(), || {
194            file.write(locked, current_task, data)
195        })
196    }
197
198    fn wait_async(
199        &self,
200        locked: &mut Locked<FileOpsCore>,
201        _socket: &Socket,
202        current_task: &CurrentTask,
203        waiter: &Waiter,
204        events: FdEvents,
205        handler: EventHandler,
206    ) -> WaitCanceler {
207        let inner = self.lock();
208        match &inner.state {
209            VsockSocketState::Connected { file, .. } => file
210                .wait_async(locked, current_task, waiter, events, handler)
211                .expect("vsock socket should be connected to a file that can be waited on"),
212            _ => inner.waiters.wait_async_fd_events(waiter, events, handler),
213        }
214    }
215
216    fn query_events(
217        &self,
218        locked: &mut Locked<FileOpsCore>,
219        _socket: &Socket,
220        current_task: &CurrentTask,
221    ) -> Result<FdEvents, Errno> {
222        self.lock().query_events(locked, current_task)
223    }
224
225    fn shutdown(
226        &self,
227        _locked: &mut Locked<FileOpsCore>,
228        _socket: &Socket,
229        _how: SocketShutdownFlags,
230    ) -> Result<(), Errno> {
231        self.lock().state = VsockSocketState::Closed;
232        Ok(())
233    }
234
235    fn close(
236        &self,
237        locked: &mut Locked<FileOpsCore>,
238        _current_task: &CurrentTask,
239        socket: &Socket,
240    ) {
241        // Call to shutdown should never fail, so unwrap is OK
242        self.shutdown(locked, socket, SocketShutdownFlags::READ | SocketShutdownFlags::WRITE)
243            .unwrap();
244    }
245
246    fn getsockname(
247        &self,
248        _locked: &mut Locked<FileOpsCore>,
249        socket: &Socket,
250    ) -> Result<SocketAddress, Errno> {
251        let inner = self.lock();
252        if let Some(address) = &inner.address {
253            Ok(address.clone())
254        } else {
255            Ok(SocketAddress::default_for_domain(socket.domain))
256        }
257    }
258
259    fn getpeername(
260        &self,
261        _locked: &mut Locked<FileOpsCore>,
262        _socket: &Socket,
263    ) -> Result<SocketAddress, Errno> {
264        let inner = self.lock();
265        match &inner.state {
266            VsockSocketState::Connected { peer_addr, .. } => Ok(peer_addr.clone()),
267            _ => {
268                error!(ENOTCONN)
269            }
270        }
271    }
272}
273
274impl VsockSocket {
275    pub fn remote_connection<L>(
276        &self,
277        locked: &mut Locked<L>,
278        socket: &Socket,
279        current_task: &CurrentTask,
280        file: FileHandle,
281    ) -> Result<(), Errno>
282    where
283        L: LockEqualOrBefore<FileOpsCore>,
284    {
285        // we only allow non-blocking files here, so that
286        // read and write on file can return EAGAIN.
287        assert!(file.flags().contains(OpenFlags::NONBLOCK));
288        if socket.socket_type != SocketType::Stream {
289            return error!(ENOTSUP);
290        }
291        if socket.domain != SocketDomain::Vsock {
292            return error!(EINVAL);
293        }
294
295        let mut inner = self.lock();
296        match &mut inner.state {
297            VsockSocketState::Listening(queue) => {
298                if queue.sockets.len() >= queue.backlog {
299                    return error!(EAGAIN);
300                }
301                let remote_socket = Socket::new(
302                    locked,
303                    current_task,
304                    SocketDomain::Vsock,
305                    SocketType::Stream,
306                    SocketProtocol::default(),
307                    /* kernel_private = */ false,
308                )?;
309                downcast_socket_to_vsock(&remote_socket).lock().state =
310                    VsockSocketState::Connected {
311                        file,
312                        peer_addr: SocketAddress::Vsock {
313                            port: u32::MAX,
314                            cid: starnix_uapi::VMADDR_CID_HOST,
315                        },
316                    };
317                queue.sockets.push_back(remote_socket);
318                inner.waiters.notify_fd_events(FdEvents::POLLIN);
319                Ok(())
320            }
321            _ => error!(EINVAL),
322        }
323    }
324}
325
326impl VsockSocketInner {
327    fn query_events<L>(
328        &self,
329        locked: &mut Locked<L>,
330        current_task: &CurrentTask,
331    ) -> Result<FdEvents, Errno>
332    where
333        L: LockEqualOrBefore<FileOpsCore>,
334    {
335        Ok(match &self.state {
336            VsockSocketState::Disconnected => FdEvents::empty(),
337            VsockSocketState::Connected { file, .. } => current_task
338                .override_creds(FullCredentials::for_kernel(), || {
339                    file.query_events(locked, current_task)
340                })?,
341            VsockSocketState::Listening(queue) => {
342                if !queue.sockets.is_empty() {
343                    FdEvents::POLLIN
344                } else {
345                    FdEvents::empty()
346                }
347            }
348            VsockSocketState::Closed => FdEvents::POLLHUP,
349        })
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fs::fuchsia::create_fuchsia_pipe;
    use crate::mm::PAGE_SIZE;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::spawn_kernel_and_run;
    use crate::vfs::EpollFileObject;
    use crate::vfs::buffers::{VecInputBuffer, VecOutputBuffer};
    use crate::vfs::socket::SocketFile;
    use futures::executor::block_on;
    use starnix_sync::Unlocked;
    use starnix_uapi::vfs::EpollEvent;
    use syncio::Zxio;
    use zx::HandleBased;

    /// Builds the `Connected` state installed directly by tests that bypass
    /// `remote_connection`: backed by `file`, with a host-CID peer address.
    fn connected_to_host(file: FileHandle) -> VsockSocketState {
        VsockSocketState::Connected {
            file,
            peer_addr: SocketAddress::Vsock { port: u32::MAX, cid: starnix_uapi::VMADDR_CID_HOST },
        }
    }

    /// Binds and listens on a vsock port, injects a remote connection, accepts
    /// it, and checks that data flows in both directions.
    #[::fuchsia::test]
    async fn test_vsock_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            const VSOCK_PORT: u32 = 5555;
            let (fs1, fs2) = fidl::Socket::create_stream();

            let bound_socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            current_task
                .abstract_vsock_namespace
                .bind(locked, &current_task, VSOCK_PORT, &bound_socket)
                .expect("Failed to bind socket.");
            bound_socket.listen(locked, &current_task, 10).expect("Failed to listen.");

            // Find the listener again through the namespace, as a remote
            // connection request would.
            let listener = current_task
                .abstract_vsock_namespace
                .lookup(&VSOCK_PORT)
                .expect("Failed to look up listening socket.");
            let remote_file = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&listener)
                .remote_connection(locked, &listener, &current_task, remote_file)
                .unwrap();

            let server_socket = listener.accept(locked, &current_task).unwrap();

            // Remote end -> accepted socket.
            let test_bytes_in: [u8; 5] = [0, 1, 2, 3, 4];
            assert_eq!(fs1.write(&test_bytes_in[..]).unwrap(), test_bytes_in.len());
            let mut received = VecOutputBuffer::new(*PAGE_SIZE as usize);
            let read_message_info = server_socket
                .read(locked, &current_task, &mut received, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_message_info.bytes_read, test_bytes_in.len());
            assert_eq!(received.data(), test_bytes_in);

            // Accepted socket -> remote end.
            let test_bytes_out: [u8; 10] = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
            let mut outgoing = VecInputBuffer::new(&test_bytes_out);
            server_socket
                .write(locked, &current_task, &mut outgoing, &mut None, &mut vec![])
                .unwrap();
            assert_eq!(outgoing.bytes_read(), test_bytes_out.len());

            let mut read_back_buf = [0u8; 100];
            assert_eq!(test_bytes_out.len(), fs1.read(&mut read_back_buf).unwrap());
            assert_eq!(&read_back_buf[..test_bytes_out.len()], &test_bytes_out);

            server_socket.close(locked, &current_task);
            listener.close(locked, &current_task);
        })
        .await;
    }

    /// Checks that a write on the socket file succeeds while another thread
    /// is blocked reading from the same socket file.
    #[::fuchsia::test]
    async fn test_vsock_write_while_read() {
        spawn_kernel_and_run(async |locked, current_task| {
            const XFER_SIZE: usize = 42;

            let kernel = current_task.kernel();
            let (fs1, fs2) = fidl::Socket::create_stream();
            let socket = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            let remote_file = create_fuchsia_pipe(
                locked,
                &current_task,
                fs2,
                OpenFlags::RDWR | OpenFlags::NONBLOCK,
            )
            .unwrap();
            downcast_socket_to_vsock(&socket).lock().state = connected_to_host(remote_file);
            let socket_file =
                SocketFile::from_socket(locked, &current_task, socket, OpenFlags::RDWR, false)
                    .expect("Failed to create socket file.");

            // Reader that runs on a kernel thread and blocks until
            // `XFER_SIZE` bytes arrive.
            let reader_file = socket_file.clone();
            let reader = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let bytes_read = reader_file
                    .read(locked, current_task, &mut VecOutputBuffer::new(XFER_SIZE))
                    .unwrap();
                assert_eq!(XFER_SIZE, bytes_read);
            };
            let (read_done, spawn_request) =
                SpawnRequestBuilder::new().with_sync_closure(reader).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(spawn_request);

            // Wait for the thread to become blocked on the read.
            std::thread::sleep(std::time::Duration::from_secs(2));

            socket_file
                .write(locked, &current_task, &mut VecInputBuffer::new(&[0; XFER_SIZE]))
                .unwrap();

            // Drain what the socket wrote, then feed the blocked reader.
            let mut scratch = [0u8; 1024];
            assert_eq!(XFER_SIZE, fs1.read(&mut scratch).unwrap());
            assert_eq!(XFER_SIZE, fs1.write(&scratch[..XFER_SIZE]).unwrap());
            block_on(read_done).unwrap();
        })
        .await;
    }

    /// Checks `query_events` and epoll readiness on a connected socket
    /// before data arrives, once data is pending, and after it is consumed.
    #[::fuchsia::test]
    async fn test_vsock_poll() {
        spawn_kernel_and_run(async |locked, current_task| {
            let (client, server) = zx::Socket::create_stream();
            let pipe = create_fuchsia_pipe(locked, &current_task, client, OpenFlags::RDWR)
                .expect("create_fuchsia_pipe");
            let server_zxio = Zxio::create(server.into_handle()).expect("Zxio::create");
            let socket_object = Socket::new(
                locked,
                &current_task,
                SocketDomain::Vsock,
                SocketType::Stream,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            downcast_socket_to_vsock(&socket_object).lock().state = connected_to_host(pipe);
            let socket = SocketFile::from_socket(
                locked,
                &current_task,
                socket_object.clone(),
                OpenFlags::RDWR,
                false,
            )
            .expect("Failed to create socket file.");

            // Nothing to read yet: only writable.
            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );

            let epoll_object = EpollFileObject::new_file(locked, &current_task);
            let epoll_file = epoll_object.downcast_file::<EpollFileObject>().unwrap();
            let read_interest = EpollEvent::new(FdEvents::POLLIN, 0);
            epoll_file
                .add(locked, &current_task, &socket, &epoll_object, read_interest)
                .expect("poll_file.add");

            let ready = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(ready.is_empty());

            // One byte from the peer makes the socket readable.
            assert_eq!(server_zxio.write(&[0]).expect("write"), 1);

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT
                    | FdEvents::POLLWRNORM
                    | FdEvents::POLLIN
                    | FdEvents::POLLRDNORM)
            );
            let ready = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert_eq!(ready.len(), 1);

            // Consuming the byte clears readability again.
            assert_eq!(
                socket.read(locked, &current_task, &mut VecOutputBuffer::new(64)).expect("read"),
                1
            );

            assert_eq!(
                socket.query_events(locked, &current_task),
                Ok(FdEvents::POLLOUT | FdEvents::POLLWRNORM)
            );
            let ready = epoll_file
                .wait(locked, &current_task, 1, zx::MonotonicInstant::ZERO)
                .expect("wait");
            assert!(ready.is_empty());
        })
        .await;
    }
}