starnix_core/vfs/socket/
socket.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::{
6    NetlinkFamily, QipcrtrSocket, SocketAddress, SocketDomain, SocketFile, SocketMessageFlags,
7    SocketProtocol, SocketShutdownFlags, SocketType, UnixSocket, VsockSocket, ZxioBackedSocket,
8    new_netlink_socket,
9};
10use crate::mm::MemoryAccessorExt;
11use crate::security;
12use crate::syscalls::time::TimeValPtr;
13use crate::task::{CurrentTask, EventHandler, WaitCanceler, Waiter};
14use crate::vfs::buffers::{AncillaryData, InputBuffer, MessageReadInfo, OutputBuffer};
15use crate::vfs::{DowncastedFile, FileHandle, FileObject, FsNodeHandle, default_ioctl};
16use starnix_logging::track_stub;
17use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex, Unlocked};
18use starnix_syscalls::{SyscallArg, SyscallResult};
19use starnix_types::time::{duration_from_timeval, timeval_from_duration};
20use starnix_types::user_buffer::UserBuffer;
21use starnix_uapi::as_any::AsAny;
22use starnix_uapi::auth::CAP_NET_RAW;
23use starnix_uapi::errors::{ENOTTY, Errno};
24use starnix_uapi::user_address::MappingMultiArchUserRef;
25use starnix_uapi::vfs::FdEvents;
26use starnix_uapi::{
27    SO_DOMAIN, SO_PROTOCOL, SO_RCVTIMEO, SO_SNDTIMEO, SO_TYPE, SOL_SOCKET, errno, error, uapi,
28};
29use std::collections::VecDeque;
30use std::sync::Arc;
31use std::sync::atomic::Ordering;
32use zerocopy::FromBytes;
33
/// Default accept-queue backlog for listening sockets. Not referenced in this file; it is exported
/// for use by the socket implementations.
pub const DEFAULT_LISTEN_BACKLOG: usize = 1024;

// TODO(https://fxbug.dev/477273398): The two constants below come from Android and are currently
// stubbed out.

/// Android-specific `SOL_SOCKET` option number, answered with a fake value by `getsockopt`.
const SO_ANDROID_DROP_REASON: u32 = 0xAD01D01;
/// The value reported for `SO_ANDROID_DROP_REASON` while the option is stubbed.
const ANDROID_DROP_REASON_NONE: u64 = 0;
39
40pub trait SocketOps: Send + Sync + AsAny {
41    /// Returns the domain, type and protocol of the socket. This is only used for socket that are
42    /// build without previous knowledge of this information, and can be ignored if all sockets are
43    /// build with it.
44    fn get_socket_info(&self) -> Result<(SocketDomain, SocketType, SocketProtocol), Errno> {
45        // This should not be used by most socket type that are created with their domain, type and
46        // protocol.
47        error!(EINVAL)
48    }
49
50    /// Connect the `socket` to the listening `peer`. On success
51    /// a new socket is created and added to the accept queue.
52    fn connect(
53        &self,
54        locked: &mut Locked<FileOpsCore>,
55        socket: &SocketHandle,
56        current_task: &CurrentTask,
57        peer: SocketPeer,
58    ) -> Result<(), Errno>;
59
60    /// Start listening at the bound address for `connect` calls.
61    fn listen(
62        &self,
63        locked: &mut Locked<FileOpsCore>,
64        socket: &Socket,
65        backlog: i32,
66        credentials: uapi::ucred,
67    ) -> Result<(), Errno>;
68
69    /// Returns the eariest socket on the accept queue of this
70    /// listening socket. Returns EAGAIN if the queue is empty.
71    fn accept(
72        &self,
73        locked: &mut Locked<FileOpsCore>,
74        socket: &Socket,
75        current_task: &CurrentTask,
76    ) -> Result<SocketHandle, Errno>;
77
78    /// Binds this socket to a `socket_address`.
79    ///
80    /// Returns an error if the socket could not be bound.
81    fn bind(
82        &self,
83        locked: &mut Locked<FileOpsCore>,
84        socket: &Socket,
85        current_task: &CurrentTask,
86        socket_address: SocketAddress,
87    ) -> Result<(), Errno>;
88
89    /// Reads the specified number of bytes from the socket, if possible.
90    ///
91    /// # Parameters
92    /// - `task`: The task to which the user buffers belong (i.e., the task to which the read bytes
93    ///           are written.
94    /// - `data`: The buffers to write the read data into.
95    ///
96    /// Returns the number of bytes that were written to the user buffers, as well as any ancillary
97    /// data associated with the read messages.
98    fn read(
99        &self,
100        locked: &mut Locked<FileOpsCore>,
101        socket: &Socket,
102        current_task: &CurrentTask,
103        data: &mut dyn OutputBuffer,
104        flags: SocketMessageFlags,
105    ) -> Result<MessageReadInfo, Errno>;
106
107    /// Writes the data in the provided user buffers to this socket.
108    ///
109    /// # Parameters
110    /// - `task`: The task to which the user buffers belong, used to read the memory.
111    /// - `data`: The data to write to the socket.
112    /// - `ancillary_data`: Optional ancillary data (a.k.a., control message) to write.
113    ///
114    /// Advances the iterator to indicate how much was actually written.
115    fn write(
116        &self,
117        locked: &mut Locked<FileOpsCore>,
118        socket: &Socket,
119        current_task: &CurrentTask,
120        data: &mut dyn InputBuffer,
121        dest_address: &mut Option<SocketAddress>,
122        ancillary_data: &mut Vec<AncillaryData>,
123    ) -> Result<usize, Errno>;
124
125    /// Queues an asynchronous wait for the specified `events`
126    /// on the `waiter`. Note that no wait occurs until a
127    /// wait functions is called on the `waiter`.
128    ///
129    /// # Parameters
130    /// - `waiter`: The Waiter that can be waited on, for example by
131    ///             calling Waiter::wait_until.
132    /// - `events`: The events that will trigger the waiter to wake up.
133    /// - `handler`: A handler that will be called on wake-up.
134    /// Returns a WaitCanceler that can be used to cancel the wait.
135    fn wait_async(
136        &self,
137        locked: &mut Locked<FileOpsCore>,
138        socket: &Socket,
139        current_task: &CurrentTask,
140        waiter: &Waiter,
141        events: FdEvents,
142        handler: EventHandler,
143    ) -> WaitCanceler;
144
145    /// Return the events that are currently active on the `socket`.
146    fn query_events(
147        &self,
148        locked: &mut Locked<FileOpsCore>,
149        socket: &Socket,
150        current_task: &CurrentTask,
151    ) -> Result<FdEvents, Errno>;
152
153    /// Shuts down this socket according to how, preventing any future reads and/or writes.
154    ///
155    /// Used by the shutdown syscalls.
156    fn shutdown(
157        &self,
158        locked: &mut Locked<FileOpsCore>,
159        socket: &Socket,
160        how: SocketShutdownFlags,
161    ) -> Result<(), Errno>;
162
163    /// Close this socket.
164    ///
165    /// Called by SocketFile when the file descriptor that is holding this
166    /// socket is closed.
167    ///
168    /// Close differs from shutdown in two ways. First, close will call
169    /// mark_peer_closed_with_unread_data if this socket has unread data,
170    /// which changes how read() behaves on that socket. Second, close
171    /// transitions the internal state of this socket to Closed, which breaks
172    /// the reference cycle that exists in the connected state.
173    fn close(&self, locked: &mut Locked<FileOpsCore>, current_task: &CurrentTask, socket: &Socket);
174
175    /// Returns the name of this socket.
176    ///
177    /// The name is derived from the address and domain. A socket
178    /// will always have a name, even if it is not bound to an address.
179    fn getsockname(
180        &self,
181        locked: &mut Locked<FileOpsCore>,
182        socket: &Socket,
183    ) -> Result<SocketAddress, Errno>;
184
185    /// Returns the name of the peer of this socket, if such a peer exists.
186    ///
187    /// Returns an error if the socket is not connected.
188    fn getpeername(
189        &self,
190        locked: &mut Locked<FileOpsCore>,
191        socket: &Socket,
192    ) -> Result<SocketAddress, Errno>;
193
194    /// Sets socket-specific options.
195    fn setsockopt(
196        &self,
197        _locked: &mut Locked<FileOpsCore>,
198        _socket: &Socket,
199        _current_task: &CurrentTask,
200        _level: u32,
201        _optname: u32,
202        _optval: SockOptValue,
203    ) -> Result<(), Errno> {
204        error!(ENOPROTOOPT)
205    }
206
207    /// Retrieves socket-specific options.
208    fn getsockopt(
209        &self,
210        _locked: &mut Locked<FileOpsCore>,
211        _socket: &Socket,
212        _current_task: &CurrentTask,
213        _level: u32,
214        _optname: u32,
215        _optlen: u32,
216    ) -> Result<Vec<u8>, Errno> {
217        error!(ENOPROTOOPT)
218    }
219
220    /// Implements ioctl.
221    fn ioctl(
222        &self,
223        locked: &mut Locked<Unlocked>,
224        _socket: &Socket,
225        file: &FileObject,
226        current_task: &CurrentTask,
227        request: u32,
228        arg: SyscallArg,
229    ) -> Result<SyscallResult, Errno> {
230        default_ioctl(file, locked, current_task, request, arg)
231    }
232
233    /// Return a handle that allows access to this file descritor through the zxio protocols.
234    ///
235    /// If None is returned, the file will be proxied.
236    fn to_handle(
237        &self,
238        _socket: &Socket,
239        _current_task: &CurrentTask,
240    ) -> Result<Option<zx::NullableHandle>, Errno> {
241        Ok(None)
242    }
243}
244
245/// A `Socket` represents one endpoint of a bidirectional communication channel.
246pub struct Socket {
247    pub(super) ops: Box<dyn SocketOps>,
248
249    /// The domain of this socket.
250    pub domain: SocketDomain,
251
252    /// The type of this socket.
253    pub socket_type: SocketType,
254
255    /// The protocol of this socket.
256    pub protocol: SocketProtocol,
257
258    state: Mutex<SocketState>,
259
260    /// Security module state associated with this socket. Note that the socket's security label is
261    /// applied to the associated `fs_node`.
262    pub security: security::SocketState,
263}
264
265#[derive(Default)]
266struct SocketState {
267    /// The value of SO_RCVTIMEO.
268    receive_timeout: Option<zx::MonotonicDuration>,
269
270    /// The value for SO_SNDTIMEO.
271    send_timeout: Option<zx::MonotonicDuration>,
272
273    /// Reference to the [`crate::vfs::FsNode`] to which this `Socket` is attached.
274    /// `None` until the `Socket` is wrapped into a [`crate::vfs::FileObject`] (e.g. while it is
275    /// still held in a listen queue).
276    fs_node: Option<FsNodeHandle>,
277}
278
279pub type SocketHandle = Arc<Socket>;
280
281#[derive(Clone)]
282pub enum SocketPeer {
283    Handle(SocketHandle),
284    Address(SocketAddress),
285}
286
287// `resolve_protocol()` returns the protocol that should be used for a new
288// socket. `socket()` allows `protocol` parameter to be set 0, in which case the
289// protocol defaults to TCP or UDP depending on the specified `socket_type`.
290fn resolve_protocol(
291    domain: SocketDomain,
292    socket_type: SocketType,
293    protocol: SocketProtocol,
294) -> SocketProtocol {
295    if domain.is_inet() && protocol.as_raw() == 0 {
296        match socket_type {
297            SocketType::Stream => SocketProtocol::TCP,
298            SocketType::Datagram => SocketProtocol::UDP,
299            _ => protocol,
300        }
301    } else {
302        protocol
303    }
304}
305
306fn create_socket_ops(
307    locked: &mut Locked<FileOpsCore>,
308    current_task: &CurrentTask,
309    domain: SocketDomain,
310    socket_type: SocketType,
311    protocol: SocketProtocol,
312) -> Result<Box<dyn SocketOps>, Errno> {
313    match domain {
314        SocketDomain::Unix => Ok(Box::new(UnixSocket::new(socket_type))),
315        SocketDomain::Vsock => Ok(Box::new(VsockSocket::new(socket_type))),
316        SocketDomain::Inet | SocketDomain::Inet6 => {
317            // Follow Linux, and require CAP_NET_RAW to create raw sockets.
318            // See https://man7.org/linux/man-pages/man7/raw.7.html.
319            if socket_type == SocketType::Raw {
320                security::check_task_capable(current_task, CAP_NET_RAW)?;
321            }
322            Ok(Box::new(ZxioBackedSocket::new(
323                locked,
324                current_task,
325                domain,
326                socket_type,
327                protocol,
328            )?))
329        }
330        SocketDomain::Netlink => {
331            let netlink_family = NetlinkFamily::from_raw(protocol.as_raw());
332            new_netlink_socket(current_task.kernel(), socket_type, netlink_family)
333        }
334        SocketDomain::Packet => {
335            // Follow Linux, and require CAP_NET_RAW to create packet sockets.
336            // See https://man7.org/linux/man-pages/man7/packet.7.html.
337            security::check_task_capable(current_task, CAP_NET_RAW)?;
338            Ok(Box::new(ZxioBackedSocket::new(
339                locked,
340                current_task,
341                domain,
342                socket_type,
343                protocol,
344            )?))
345        }
346        SocketDomain::Key => {
347            track_stub!(
348                TODO("https://fxbug.dev/323365389"),
349                "Returning a UnixSocket instead of a KeySocket"
350            );
351            Ok(Box::new(UnixSocket::new(SocketType::Datagram)))
352        }
353        SocketDomain::Qipcrtr => Ok(Box::new(QipcrtrSocket::new(socket_type))),
354    }
355}
356
357#[derive(Debug)]
358pub enum SockOptValue {
359    Value(Vec<u8>),
360    User(UserBuffer),
361}
362
363impl From<Vec<u8>> for SockOptValue {
364    fn from(buffer: Vec<u8>) -> Self {
365        Self::Value(buffer)
366    }
367}
368
369impl From<UserBuffer> for SockOptValue {
370    fn from(buffer: UserBuffer) -> Self {
371        Self::User(buffer)
372    }
373}
374
375impl SockOptValue {
376    pub fn len(&self) -> usize {
377        match self {
378            Self::Value(buffer) => buffer.len(),
379            Self::User(user_buffer) => user_buffer.length,
380        }
381    }
382
383    pub fn read<T: FromBytes>(&self, current_task: &CurrentTask) -> Result<T, Errno> {
384        match self {
385            Self::Value(buffer) => {
386                T::read_from_prefix(&buffer).map_err(|_| errno!(EINVAL)).map(|(v, _)| v)
387            }
388            Self::User(user_buffer) => {
389                current_task.read_object::<T>(user_buffer.clone().try_into()?)
390            }
391        }
392    }
393
394    pub fn read_bytes(
395        &self,
396        current_task: &CurrentTask,
397        max_bytes: usize,
398    ) -> Result<Vec<u8>, Errno> {
399        match self {
400            Self::Value(buffer) => {
401                let bytes = std::cmp::min(max_bytes, buffer.len());
402                Ok(buffer[..bytes].to_owned())
403            }
404            Self::User(user_buffer) => {
405                let bytes = std::cmp::min(max_bytes, user_buffer.length);
406                current_task
407                    .read_buffer(&UserBuffer { address: user_buffer.address, length: bytes })
408            }
409        }
410    }
411
412    pub fn to_vec(self, current_task: &CurrentTask) -> Result<Vec<u8>, Errno> {
413        match self {
414            Self::Value(buffer) => Ok(buffer),
415            Self::User(user_buffer) => current_task.read_buffer(&user_buffer),
416        }
417    }
418}
419
420// Trait used to provide `read_from_sockopt_value` for `MappingMultiArchUserRef`.
421pub trait ReadFromSockOptValue {
422    type Result;
423    fn read_from_sockopt_value(
424        current_task: &CurrentTask,
425        buffer: &SockOptValue,
426    ) -> Result<Self::Result, Errno>;
427}
428
429impl<T, T64, T32> ReadFromSockOptValue for MappingMultiArchUserRef<T, T64, T32>
430where
431    T64: FromBytes + TryInto<T>,
432    T32: FromBytes + TryInto<T>,
433{
434    type Result = T;
435    fn read_from_sockopt_value(
436        current_task: &CurrentTask,
437        buffer: &SockOptValue,
438    ) -> Result<T, Errno> {
439        match buffer {
440            SockOptValue::Value(buffer) => {
441                Self::read_from_prefix(current_task, &buffer).map_err(|_| errno!(EINVAL))
442            }
443            SockOptValue::User(user_buffer) => {
444                let user_ref = Self::new_with_ref(current_task, user_buffer.clone())?;
445                current_task.read_multi_arch_object(user_ref)
446            }
447        }
448    }
449}
450
451impl Socket {
452    /// Creates a new unbound socket.
453    ///
454    /// # Parameters
455    /// - `domain`: The domain of the socket (e.g., `AF_UNIX`).
456    pub fn new<L>(
457        locked: &mut Locked<L>,
458        current_task: &CurrentTask,
459        domain: SocketDomain,
460        socket_type: SocketType,
461        protocol: SocketProtocol,
462        kernel_private: bool,
463    ) -> Result<SocketHandle, Errno>
464    where
465        L: LockEqualOrBefore<FileOpsCore>,
466    {
467        let protocol = resolve_protocol(domain, socket_type, protocol);
468        // Checking access in `Socket::new()` prevents creating socket handles when not allowed,
469        // while skipping the "create" permission check for accepted sockets created with
470        // `Socket::new_with_ops()` and `Socket::new_with_ops_and_info()`.
471        security::check_socket_create_access(
472            locked,
473            current_task,
474            domain,
475            socket_type,
476            protocol,
477            kernel_private,
478        )?;
479        let ops =
480            create_socket_ops(locked.cast_locked(), current_task, domain, socket_type, protocol)?;
481        Ok(Self::new_with_ops_and_info(ops, domain, socket_type, protocol))
482    }
483
484    pub fn new_with_ops(ops: Box<dyn SocketOps>) -> Result<SocketHandle, Errno> {
485        let (domain, socket_type, protocol) = ops.get_socket_info()?;
486        Ok(Self::new_with_ops_and_info(ops, domain, socket_type, protocol))
487    }
488
489    pub fn new_with_ops_and_info(
490        ops: Box<dyn SocketOps>,
491        domain: SocketDomain,
492        socket_type: SocketType,
493        protocol: SocketProtocol,
494    ) -> SocketHandle {
495        Arc::new(Socket {
496            ops,
497            domain,
498            socket_type,
499            protocol,
500            state: Mutex::default(),
501            security: security::SocketState::default(),
502        })
503    }
504
505    pub(super) fn set_fs_node(&self, node: &FsNodeHandle) {
506        let mut locked_state = self.state.lock();
507        assert!(locked_state.fs_node.is_none());
508        locked_state.fs_node = Some(node.clone());
509    }
510
511    /// Returns the Socket that this FileHandle refers to. If this file is not a socket file,
512    /// returns ENOTSOCK.
513    pub fn get_from_file(file: &FileHandle) -> Result<&SocketHandle, Errno> {
514        let socket_file = file.downcast_file::<SocketFile>().ok_or_else(|| errno!(ENOTSOCK))?;
515        Ok(&socket_file.socket)
516    }
517
518    pub fn downcast_socket<T>(&self) -> Option<&T>
519    where
520        T: 'static,
521    {
522        let ops = &*self.ops;
523        ops.as_any().downcast_ref::<T>()
524    }
525
526    pub fn getsockname<L>(&self, locked: &mut Locked<L>) -> Result<SocketAddress, Errno>
527    where
528        L: LockEqualOrBefore<FileOpsCore>,
529    {
530        self.ops.getsockname(locked.cast_locked::<FileOpsCore>(), self)
531    }
532
533    pub fn getpeername<L>(&self, locked: &mut Locked<L>) -> Result<SocketAddress, Errno>
534    where
535        L: LockEqualOrBefore<FileOpsCore>,
536    {
537        self.ops.getpeername(locked.cast_locked::<FileOpsCore>(), self)
538    }
539
540    pub fn setsockopt<L>(
541        &self,
542        locked: &mut Locked<L>,
543        current_task: &CurrentTask,
544        level: u32,
545        optname: u32,
546        optval: SockOptValue,
547    ) -> Result<(), Errno>
548    where
549        L: LockEqualOrBefore<FileOpsCore>,
550    {
551        let locked = locked.cast_locked::<FileOpsCore>();
552        let read_timeval = || {
553            let timeval = TimeValPtr::read_from_sockopt_value(current_task, &optval)?;
554            let duration = duration_from_timeval(timeval)?;
555            Ok(if duration == zx::MonotonicDuration::default() { None } else { Some(duration) })
556        };
557
558        security::check_socket_setsockopt_access(current_task, self, level, optname)?;
559        match (level, optname) {
560            (SOL_SOCKET, SO_RCVTIMEO) => self.state.lock().receive_timeout = read_timeval()?,
561            (SOL_SOCKET, SO_SNDTIMEO) => self.state.lock().send_timeout = read_timeval()?,
562            _ => self.ops.setsockopt(locked, self, current_task, level, optname, optval)?,
563        }
564        Ok(())
565    }
566
567    pub fn getsockopt<L>(
568        &self,
569        locked: &mut Locked<L>,
570        current_task: &CurrentTask,
571        level: u32,
572        optname: u32,
573        optlen: u32,
574    ) -> Result<Vec<u8>, Errno>
575    where
576        L: LockEqualOrBefore<FileOpsCore>,
577    {
578        let locked = locked.cast_locked::<FileOpsCore>();
579        security::check_socket_getsockopt_access(current_task, self, level, optname)?;
580        let value = match level {
581            SOL_SOCKET => match optname {
582                SO_TYPE => self.socket_type.as_raw().to_ne_bytes().to_vec(),
583                SO_DOMAIN => {
584                    let domain = self.domain.as_raw() as u32;
585                    domain.to_ne_bytes().to_vec()
586                }
587                SO_PROTOCOL => self.protocol.as_raw().to_ne_bytes().to_vec(),
588                SO_RCVTIMEO => {
589                    let duration = self.receive_timeout().unwrap_or_default();
590                    TimeValPtr::into_bytes(current_task, timeval_from_duration(duration))
591                        .map_err(|_| errno!(EINVAL))?
592                }
593                SO_SNDTIMEO => {
594                    let duration = self.send_timeout().unwrap_or_default();
595                    TimeValPtr::into_bytes(current_task, timeval_from_duration(duration))
596                        .map_err(|_| errno!(EINVAL))?
597                }
598                SO_ANDROID_DROP_REASON => {
599                    track_stub!(
600                        TODO("https://fxbug.dev/477273398"),
601                        "Faking SO_ANDROID_DROP_REASON"
602                    );
603                    ANDROID_DROP_REASON_NONE.to_ne_bytes().to_vec()
604                }
605                _ => self.ops.getsockopt(locked, self, current_task, level, optname, optlen)?,
606            },
607            _ => self.ops.getsockopt(locked, self, current_task, level, optname, optlen)?,
608        };
609        Ok(value)
610    }
611
612    pub fn receive_timeout(&self) -> Option<zx::MonotonicDuration> {
613        self.state.lock().receive_timeout
614    }
615
616    pub fn send_timeout(&self) -> Option<zx::MonotonicDuration> {
617        self.state.lock().send_timeout
618    }
619
620    pub fn ioctl(
621        &self,
622        locked: &mut Locked<Unlocked>,
623        file: &FileObject,
624        current_task: &CurrentTask,
625        request: u32,
626        arg: SyscallArg,
627    ) -> Result<SyscallResult, Errno> {
628        let res = super::netlink_ioctl::netlink_ioctl(locked, current_task, request, arg);
629        match &res {
630            Err(e) if e.code == ENOTTY => {}
631            _ => return res,
632        }
633        self.ops.ioctl(locked, self, file, current_task, request, arg)
634    }
635
636    pub fn bind<L>(
637        &self,
638        locked: &mut Locked<L>,
639        current_task: &CurrentTask,
640        socket_address: SocketAddress,
641    ) -> Result<(), Errno>
642    where
643        L: LockEqualOrBefore<FileOpsCore>,
644    {
645        self.ops.bind(locked.cast_locked::<FileOpsCore>(), self, current_task, socket_address)
646    }
647
648    pub fn listen<L>(
649        &self,
650        locked: &mut Locked<L>,
651        current_task: &CurrentTask,
652        backlog: i32,
653    ) -> Result<(), Errno>
654    where
655        L: LockEqualOrBefore<FileOpsCore>,
656    {
657        security::check_socket_listen_access(current_task, self, backlog)?;
658        let max_connections =
659            current_task.kernel().system_limits.socket.max_connections.load(Ordering::Relaxed);
660        let backlog = std::cmp::min(backlog, max_connections);
661        let credentials = current_task.current_ucred();
662        self.ops.listen(locked.cast_locked::<FileOpsCore>(), self, backlog, credentials)
663    }
664
665    pub fn accept<L>(
666        &self,
667        locked: &mut Locked<L>,
668        current_task: &CurrentTask,
669    ) -> Result<SocketHandle, Errno>
670    where
671        L: LockEqualOrBefore<FileOpsCore>,
672    {
673        self.ops.accept(locked.cast_locked::<FileOpsCore>(), self, current_task)
674    }
675
676    pub fn read<L>(
677        &self,
678        locked: &mut Locked<L>,
679        current_task: &CurrentTask,
680        data: &mut dyn OutputBuffer,
681        flags: SocketMessageFlags,
682    ) -> Result<MessageReadInfo, Errno>
683    where
684        L: LockEqualOrBefore<FileOpsCore>,
685    {
686        security::check_socket_recvmsg_access(current_task, self)?;
687        let locked = locked.cast_locked::<FileOpsCore>();
688        self.ops.read(locked, self, current_task, data, flags)
689    }
690
691    pub fn write<L>(
692        &self,
693        locked: &mut Locked<L>,
694        current_task: &CurrentTask,
695        data: &mut dyn InputBuffer,
696        dest_address: &mut Option<SocketAddress>,
697        ancillary_data: &mut Vec<AncillaryData>,
698    ) -> Result<usize, Errno>
699    where
700        L: LockEqualOrBefore<FileOpsCore>,
701    {
702        security::check_socket_sendmsg_access(current_task, self)?;
703        let locked = locked.cast_locked::<FileOpsCore>();
704        self.ops.write(locked, self, current_task, data, dest_address, ancillary_data)
705    }
706
707    pub fn wait_async<L>(
708        &self,
709        locked: &mut Locked<L>,
710        current_task: &CurrentTask,
711        waiter: &Waiter,
712        events: FdEvents,
713        handler: EventHandler,
714    ) -> WaitCanceler
715    where
716        L: LockEqualOrBefore<FileOpsCore>,
717    {
718        let locked = locked.cast_locked::<FileOpsCore>();
719        self.ops.wait_async(locked, self, current_task, waiter, events, handler)
720    }
721
722    pub fn query_events<L>(
723        &self,
724        locked: &mut Locked<L>,
725        current_task: &CurrentTask,
726    ) -> Result<FdEvents, Errno>
727    where
728        L: LockEqualOrBefore<FileOpsCore>,
729    {
730        self.ops.query_events(locked.cast_locked::<FileOpsCore>(), self, current_task)
731    }
732
733    pub fn shutdown<L>(
734        &self,
735        locked: &mut Locked<L>,
736        current_task: &CurrentTask,
737        how: SocketShutdownFlags,
738    ) -> Result<(), Errno>
739    where
740        L: LockEqualOrBefore<FileOpsCore>,
741    {
742        security::check_socket_shutdown_access(current_task, self, how)?;
743        self.ops.shutdown(locked.cast_locked::<FileOpsCore>(), self, how)
744    }
745
746    pub fn close<L>(&self, locked: &mut Locked<L>, current_task: &CurrentTask)
747    where
748        L: LockEqualOrBefore<FileOpsCore>,
749    {
750        self.ops.close(locked.cast_locked::<FileOpsCore>(), current_task, self)
751    }
752
753    pub fn to_handle(
754        &self,
755        _file: &FileObject,
756        current_task: &CurrentTask,
757    ) -> Result<Option<zx::NullableHandle>, Errno> {
758        self.ops.to_handle(self, current_task)
759    }
760
761    /// Returns the [`crate::vfs::FsNode`] unique to this `Socket`.
762    // TODO: https://fxbug.dev/414583985 - Create `FsNode` at `Socket` creation and make this
763    // infallible.
764    pub fn fs_node(&self) -> Option<FsNodeHandle> {
765        self.state.lock().fs_node.clone()
766    }
767}
768
769impl DowncastedFile<'_, SocketFile> {
770    pub fn connect<L>(
771        self,
772        locked: &mut Locked<L>,
773        current_task: &CurrentTask,
774        peer: SocketPeer,
775    ) -> Result<(), Errno>
776    where
777        L: LockEqualOrBefore<FileOpsCore>,
778    {
779        security::check_socket_connect_access(current_task, self, &peer)?;
780        self.socket.ops.connect(locked.cast_locked(), &self.socket, current_task, peer)
781    }
782}
783
784pub struct AcceptQueue {
785    pub sockets: VecDeque<SocketHandle>,
786    pub backlog: usize,
787}
788
789impl AcceptQueue {
790    pub fn new(backlog: usize) -> AcceptQueue {
791        AcceptQueue { sockets: VecDeque::with_capacity(backlog), backlog }
792    }
793
794    pub fn set_backlog(&mut self, backlog: usize) -> Result<(), Errno> {
795        if self.sockets.len() > backlog {
796            return error!(EINVAL);
797        }
798        self.backlog = backlog;
799        Ok(())
800    }
801}
802
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::{map_memory, spawn_kernel_and_run};
    use crate::vfs::{UnixControlData, VecInputBuffer, VecOutputBuffer};
    use starnix_uapi::SO_PASSCRED;
    use starnix_uapi::user_address::{UserAddress, UserRef};

    /// Checks that a datagram sent to a connected Unix datagram socket stays readable after the
    /// sender is closed. It also checks that `SO_PASSCRED` credentials are delivered with it.
    #[fuchsia::test]
    async fn test_dgram_socket() {
        spawn_kernel_and_run(async |locked, current_task| {
            let bind_address = SocketAddress::Unix(b"dgram_test".into());
            let rec_dgram = Socket::new(
                locked,
                &current_task,
                SocketDomain::Unix,
                SocketType::Datagram,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            let passcred: u32 = 1;
            let opt_size = std::mem::size_of::<u32>();
            let user_address =
                map_memory(locked, &current_task, UserAddress::default(), opt_size as u64);
            let opt_ref = UserRef::<u32>::new(user_address);
            current_task.write_object(opt_ref, &passcred).unwrap();
            let opt_buf = UserBuffer { address: user_address, length: opt_size };
            rec_dgram
                .setsockopt(locked, &current_task, SOL_SOCKET, SO_PASSCRED, opt_buf.into())
                .unwrap();

            rec_dgram
                .bind(locked, &current_task, bind_address)
                .expect("failed to bind datagram socket");

            let xfer_value: u64 = 1234567819;
            let xfer_bytes = xfer_value.to_ne_bytes();

            let send = Socket::new(
                locked,
                &current_task,
                SocketDomain::Unix,
                SocketType::Datagram,
                SocketProtocol::default(),
                /* kernel_private = */ false,
            )
            .expect("Failed to create socket.");
            send.ops
                .connect(
                    locked.cast_locked(),
                    &send,
                    &current_task,
                    SocketPeer::Handle(rec_dgram.clone()),
                )
                .unwrap();
            let mut source_iter = VecInputBuffer::new(&xfer_bytes);
            send.write(locked, &current_task, &mut source_iter, &mut None, &mut vec![]).unwrap();
            assert_eq!(source_iter.available(), 0);
            // Closing the sender must not shut down the receiving socket: the datagram queued
            // above has to remain readable (this used to fail).
            send.close(locked, &current_task);

            let mut rec_buffer = VecOutputBuffer::new(8);
            let read_info = rec_dgram
                .read(locked, &current_task, &mut rec_buffer, SocketMessageFlags::empty())
                .unwrap();
            assert_eq!(read_info.bytes_read, xfer_bytes.len());
            assert_eq!(rec_buffer.data(), xfer_bytes);
            assert_eq!(read_info.ancillary_data.len(), 1);
            assert_eq!(
                read_info.ancillary_data[0],
                AncillaryData::Unix(UnixControlData::Credentials(uapi::ucred {
                    pid: current_task.get_pid(),
                    uid: 0,
                    gid: 0
                }))
            );

            rec_dgram.close(locked, &current_task);
        })
        .await;
    }
}
886}