starnix_core/vfs/socket/
syscalls.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::bpf::attachments::SetSockOptProgramResult;
6use crate::mm::{IOVecPtr, MemoryAccessor, MemoryAccessorExt};
7use crate::security;
8use crate::syscalls::time::TimeSpecPtr;
9use crate::task::{CurrentTask, IpTables, Task, WaitCallback, Waiter};
10use crate::vfs::buffers::{
11    AncillaryData, ControlMsg, UserBuffersInputBuffer, UserBuffersOutputBuffer,
12};
13use crate::vfs::socket::{
14    SA_FAMILY_SIZE, SA_STORAGE_SIZE, Socket, SocketAddress, SocketDomain, SocketFile,
15    SocketMessageFlags, SocketPeer, SocketProtocol, SocketShutdownFlags, SocketType, UnixSocket,
16    resolve_unix_socket_address,
17};
18use crate::vfs::{FdFlags, FdNumber, FileHandle, FsString, LookupContext};
19use starnix_logging::{log_trace, track_stub};
20use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Unlocked};
21use starnix_types::augmented::Augmented;
22use starnix_types::time::duration_from_timespec;
23use starnix_types::user_buffer::{UserBuffer, UserBuffers};
24use starnix_uapi::auth::CAP_NET_BIND_SERVICE;
25use starnix_uapi::errors::{EEXIST, EINPROGRESS, Errno};
26use starnix_uapi::file_mode::FileMode;
27use starnix_uapi::math::round_up_to_increment;
28use starnix_uapi::open_flags::OpenFlags;
29use starnix_uapi::user_address::{
30    ArchSpecific, MappingMultiArchUserRef, MultiArchUserRef, UserAddress, UserRef,
31};
32use starnix_uapi::user_value::UserValue;
33use starnix_uapi::vfs::FdEvents;
34use starnix_uapi::{
35    MSG_CTRUNC, MSG_DONTWAIT, MSG_TRUNC, MSG_WAITFORONE, SHUT_RD, SHUT_RDWR, SHUT_WR, SOCK_CLOEXEC,
36    SOCK_NONBLOCK, UIO_MAXIOV, errno, error, socklen_t, uapi,
37};
38use std::ops::DerefMut;
39
// Registers `socklen_t` with `check_arch_independent_layout!`.
// NOTE(review): presumably this asserts at compile time that `socklen_t` has the same layout on
// every supported architecture, which is what lets the syscalls below take a plain
// `UserRef<socklen_t>` — confirm against the macro definition.
uapi::check_arch_independent_layout! {
    socklen_t {}
}
43
/// A `msghdr` can be augmented with a `UserBuffer`. In that case, the `UserBuffer` is used for
/// the I/O, instead of the `iovec` fields from the `msghdr`.
pub type WithAlternateBuffer<T> = Augmented<T, UserBuffer>;
/// User-space reference to a `msghdr`, read and written as a [`MsgHdr`] and backed by either
/// the `uapi::msghdr` or the `uapi::arch32::msghdr` layout.
pub type MsgHdrPtr = MappingMultiArchUserRef<MsgHdr, uapi::msghdr, uapi::arch32::msghdr>;
48
/// Architecture-neutral view of a `msghdr`.
///
/// The fields correspond one-to-one to the `msg_*` fields of the uapi struct through the
/// `arch_map_data!` mapping declared in this file.
#[derive(Debug, Clone)]
pub struct MsgHdr {
    /// `msg_name`: user address of the socket address buffer (may be null).
    pub name: UserAddress,
    /// `msg_namelen`: size of the buffer at `name`.
    pub name_len: socklen_t,
    /// `msg_iov`: user pointer to the array of I/O vectors.
    pub iov: IOVecPtr,
    /// `msg_iovlen`: number of entries in the array at `iov`.
    pub iovlen: UserValue<usize>,
    /// `msg_control`: user address of the ancillary data buffer.
    pub control: UserAddress,
    /// `msg_controllen`: size of the buffer at `control`.
    pub control_len: usize,
    /// `msg_flags`: flags reported back by receive operations.
    pub flags: u32,
}
59
/// A reference to a `msghdr`.
///
/// This enum is used to abstract over whether the `msghdr` is in user memory (and needs to be
/// read) or has been constructed in the kernel. This is used by `io_uring` to provide a buffer
/// for `recvmsg`.
#[derive(Debug, Clone)]
pub enum MsgHdrRef {
    /// The header lives in user memory; it is read before the operation and written back after.
    Ptr(MsgHdrPtr),
    /// The header was built in the kernel, optionally augmented with an alternate buffer; it is
    /// updated in place after the operation.
    Value(WithAlternateBuffer<MsgHdr>),
}
70
71impl From<MsgHdrPtr> for MsgHdrRef {
72    fn from(ptr: MsgHdrPtr) -> Self {
73        Self::Ptr(ptr)
74    }
75}
76
77impl From<WithAlternateBuffer<MsgHdr>> for MsgHdrRef {
78    fn from(value: WithAlternateBuffer<MsgHdr>) -> Self {
79        Self::Value(value)
80    }
81}
82
/// User-space reference to an `mmsghdr`, read and written as an [`MMsgHdr`] and backed by
/// either the `uapi::mmsghdr` or the `uapi::arch32::mmsghdr` layout.
pub type MMsgHdrPtr = MappingMultiArchUserRef<MMsgHdr, uapi::mmsghdr, uapi::arch32::mmsghdr>;
84
/// Architecture-neutral view of an `mmsghdr`, as used by `sendmmsg` and `recvmmsg`.
pub struct MMsgHdr {
    // `msg_hdr`: the message header for this entry.
    hdr: MsgHdr,
    // `msg_len`: set by `sys_sendmmsg`/`sys_recvmmsg` to the byte count returned for this entry.
    len: usize,
}
89
// Declares the field-by-field correspondence between the kernel-side structs (left-hand names)
// and the uapi structs (right-hand names).
// NOTE(review): `BidiTryFrom` presumably generates fallible conversions in both directions for
// each supported architecture — confirm against the macro definition.
uapi::arch_map_data! {
    BidiTryFrom<MsgHdr, msghdr> {
        name = msg_name;
        name_len = msg_namelen;
        iov = msg_iov;
        iovlen = msg_iovlen;
        control = msg_control;
        control_len = msg_controllen;
        flags = msg_flags;
    }

    BidiTryFrom<MMsgHdr, mmsghdr> {
        hdr = msg_hdr;
        len = msg_len;
    }
}
106
/// User-space reference to a `cmsghdr` (ancillary data header) in either the native or the
/// arch32 layout.
pub type CMsgHdrPtr = MultiArchUserRef<uapi::cmsghdr, uapi::arch32::cmsghdr>;
108
109pub fn sys_socket(
110    locked: &mut Locked<Unlocked>,
111    current_task: &CurrentTask,
112    domain: u32,
113    socket_type: u32,
114    protocol: u32,
115) -> Result<FdNumber, Errno> {
116    let flags = socket_type & (SOCK_NONBLOCK | SOCK_CLOEXEC);
117    let domain = parse_socket_domain(domain)?;
118    let socket_type = parse_socket_type(domain, socket_type)?;
119    // Should we use parse_socket_protocol here?
120    let protocol = SocketProtocol::from_raw(protocol);
121    let open_flags = socket_flags_to_open_flags(flags);
122    let socket_file = SocketFile::new_socket(
123        locked,
124        current_task,
125        domain,
126        socket_type,
127        open_flags,
128        protocol,
129        /*kernel_private=*/ false,
130    )?;
131
132    let fd_flags = socket_flags_to_fd_flags(flags);
133    let fd = current_task.add_file(locked, socket_file, fd_flags)?;
134    Ok(fd)
135}
136
137fn socket_flags_to_open_flags(flags: u32) -> OpenFlags {
138    OpenFlags::RDWR
139        | if flags & SOCK_NONBLOCK != 0 { OpenFlags::NONBLOCK } else { OpenFlags::empty() }
140}
141
142fn socket_flags_to_fd_flags(flags: u32) -> FdFlags {
143    if flags & SOCK_CLOEXEC != 0 { FdFlags::CLOEXEC } else { FdFlags::empty() }
144}
145
146fn parse_socket_domain(domain: u32) -> Result<SocketDomain, Errno> {
147    SocketDomain::from_raw(domain.try_into().map_err(|_| errno!(EAFNOSUPPORT))?).ok_or_else(|| {
148        track_stub!(TODO("https://fxbug.dev/322875074"), "parse socket domain", domain);
149        errno!(EAFNOSUPPORT)
150    })
151}
152
153fn parse_socket_type(domain: SocketDomain, socket_type: u32) -> Result<SocketType, Errno> {
154    let socket_type = SocketType::from_raw(socket_type & 0xf).ok_or_else(|| {
155        track_stub!(TODO("https://fxbug.dev/322875418"), "parse socket type", socket_type);
156        errno!(EINVAL)
157    })?;
158    // For AF_UNIX, SOCK_RAW sockets are treated as if they were SOCK_DGRAM.
159    Ok(if domain == SocketDomain::Unix && socket_type == SocketType::Raw {
160        SocketType::Datagram
161    } else {
162        socket_type
163    })
164}
165
166fn parse_socket_protocol(
167    domain: SocketDomain,
168    socket_type: SocketType,
169    protocol: u32,
170) -> Result<SocketProtocol, Errno> {
171    let protocol = SocketProtocol::from_raw(protocol);
172    if domain == SocketDomain::Inet {
173        match (socket_type, protocol) {
174            (SocketType::Raw, _) => {
175                // Should we have different behavior error when called by root?
176                return error!(EPROTONOSUPPORT);
177            }
178            (SocketType::Datagram, SocketProtocol::UDP) => (),
179            (SocketType::Datagram, _) => return error!(EPROTONOSUPPORT),
180            (SocketType::Stream, SocketProtocol::TCP) => (),
181            (SocketType::Stream, _) => return error!(EPROTONOSUPPORT),
182            _ => (),
183        }
184    }
185    Ok(protocol)
186}
187
188fn parse_socket_address(
189    task: &Task,
190    user_socket_address: UserAddress,
191    user_address_length: usize,
192) -> Result<SocketAddress, Errno> {
193    if user_address_length < SA_FAMILY_SIZE || user_address_length > SA_STORAGE_SIZE {
194        return error!(EINVAL);
195    }
196
197    let address = task.read_memory_to_vec(user_socket_address, user_address_length)?;
198
199    SocketAddress::from_bytes(address)
200}
201
202fn maybe_parse_socket_address(
203    task: &Task,
204    user_socket_address: UserAddress,
205    user_address_length: usize,
206) -> Result<Option<SocketAddress>, Errno> {
207    if user_address_length > i32::MAX as usize {
208        return error!(EINVAL);
209    }
210    Ok(if user_socket_address.is_null() {
211        None
212    } else {
213        Some(parse_socket_address(task, user_socket_address, user_address_length)?)
214    })
215}
216
217// See "Autobind feature" section of https://man7.org/linux/man-pages/man7/unix.7.html
218fn generate_autobind_address() -> FsString {
219    let mut bytes = [0u8; 4];
220    zx::cprng_draw(&mut bytes);
221    let value = u32::from_ne_bytes(bytes) & 0xFFFFF;
222    format!("\0{value:05x}").into()
223}
224
/// Binds the socket referred to by `fd` to the address read from user memory.
///
/// The steps, in order:
/// 1. Parse the address and reject it if it is not valid for the socket's domain (`EINVAL`,
///    or `EAFNOSUPPORT` for `Inet` sockets).
/// 2. For addresses with an inet port in `1..1024`, require `CAP_NET_BIND_SERVICE` (`EACCES`).
/// 3. Run the security bind-access check.
/// 4. Dispatch on the address kind: abstract or autobind unix names go to the abstract socket
///    namespace, filesystem unix names create a socket node, vsock ports go to the vsock
///    namespace, and everything else is delegated to `Socket::bind`.
pub fn sys_bind(
    locked: &mut Locked<Unlocked>,
    current_task: &CurrentTask,
    fd: FdNumber,
    user_socket_address: UserAddress,
    user_address_length: usize,
) -> Result<(), Errno> {
    let file = current_task.files.get(fd)?;
    let socket = Socket::get_from_file(&file)?;
    let address = parse_socket_address(current_task, user_socket_address, user_address_length)?;
    if !address.valid_for_domain(socket.domain) {
        // The errno for a mismatched address depends on the socket's domain.
        return match socket.domain {
            SocketDomain::Unix
            | SocketDomain::Vsock
            | SocketDomain::Inet6
            | SocketDomain::Netlink
            | SocketDomain::Key
            | SocketDomain::Packet
            | SocketDomain::Qipcrtr => error!(EINVAL),
            SocketDomain::Inet => error!(EAFNOSUPPORT),
        };
    }
    if let Some(port) = address.maybe_inet_port() {
        // See <https://man7.org/linux/man-pages/man7/ip.7.html>:
        //
        //   The port numbers below 1024 are called privileged ports (or
        //   sometimes: reserved ports).  Only a privileged process (on Linux:
        //   a process that has the CAP_NET_BIND_SERVICE capability in the
        //   user namespace governing its network namespace) may bind(2) to
        //   these sockets.
        if port != 0 && port < 1024 {
            // Whatever error the capability check produced is reported as EACCES.
            security::check_task_capable(current_task, CAP_NET_BIND_SERVICE)
                .map_err(|_| errno!(EACCES))?;
        }
    }
    security::check_socket_bind_access(current_task, socket, &address)?;
    match address {
        SocketAddress::Unspecified => return error!(EINVAL),
        SocketAddress::Unix(mut name) => {
            if name.is_empty() {
                // If the name is empty, then we're supposed to generate an
                // autobind address, which is always abstract.
                name = generate_autobind_address();
            }
            // If there is a null byte at the start of the sun_path, then the
            // address is abstract.
            if name[0] == b'\0' {
                current_task.abstract_socket_namespace.bind(locked, current_task, name, socket)?;
            } else {
                // The socket node's mode is the socket file's mode with the task's umask
                // applied and the file type set to IFSOCK.
                let mode = file.node().info().mode;
                let mode = current_task.fs().apply_umask(mode).with_type(FileMode::IFSOCK);
                let (parent, basename) = current_task.lookup_parent_at(
                    locked,
                    &mut LookupContext::default(),
                    FdNumber::AT_FDCWD,
                    name.as_ref(),
                )?;

                parent
                    .bind_socket(
                        locked,
                        current_task,
                        basename,
                        socket.clone(),
                        SocketAddress::Unix(name.clone()),
                        mode,
                    )
                    // An already-existing path is reported as "address in use".
                    .map_err(|errno| if errno == EEXIST { errno!(EADDRINUSE) } else { errno })?;
            }
        }
        SocketAddress::Vsock { port, .. } => {
            current_task.abstract_vsock_namespace.bind(locked, current_task, port, socket)?;
        }
        SocketAddress::Inet(_)
        | SocketAddress::Inet6(_)
        | SocketAddress::Netlink(_)
        | SocketAddress::Packet(_)
        | SocketAddress::Qipcrtr(_) => socket.bind(locked, current_task, address)?,
    }

    Ok(())
}
307
308pub fn sys_listen(
309    locked: &mut Locked<Unlocked>,
310    current_task: &CurrentTask,
311    fd: FdNumber,
312    backlog: i32,
313) -> Result<(), Errno> {
314    let file = current_task.files.get(fd)?;
315    let socket = Socket::get_from_file(&file)?;
316    socket.listen(locked, current_task, backlog)?;
317    Ok(())
318}
319
320pub fn sys_accept(
321    locked: &mut Locked<Unlocked>,
322    current_task: &CurrentTask,
323    fd: FdNumber,
324    user_socket_address: UserAddress,
325    user_address_length: UserRef<socklen_t>,
326) -> Result<FdNumber, Errno> {
327    sys_accept4(locked, current_task, fd, user_socket_address, user_address_length, 0)
328}
329
330pub fn sys_accept4(
331    locked: &mut Locked<Unlocked>,
332    current_task: &CurrentTask,
333    fd: FdNumber,
334    user_socket_address: UserAddress,
335    user_address_length: UserRef<socklen_t>,
336    flags: u32,
337) -> Result<FdNumber, Errno> {
338    let file = current_task.files.get(fd)?;
339    let listening_socket = Socket::get_from_file(&file)?;
340    let accepted_socket = file.blocking_op(
341        locked,
342        current_task,
343        FdEvents::POLLIN | FdEvents::POLLHUP,
344        None,
345        |locked| listening_socket.accept(locked, current_task),
346    )?;
347
348    if !user_socket_address.is_null() {
349        let address_bytes = accepted_socket.getpeername(locked)?.to_bytes();
350        write_socket_address(
351            current_task,
352            user_socket_address,
353            user_address_length,
354            &address_bytes,
355        )?;
356    }
357
358    let open_flags = socket_flags_to_open_flags(flags);
359    let accepted_socket_file = SocketFile::from_socket(
360        locked,
361        current_task,
362        accepted_socket,
363        open_flags,
364        /* kernel_private= */ false,
365    )?;
366    let listening_socket = SocketFile::get_from_file(&file)?;
367    let accepted_socket = SocketFile::get_from_file(&accepted_socket_file)?;
368    security::socket_accept(current_task, listening_socket, accepted_socket)?;
369    let fd_flags = if flags & SOCK_CLOEXEC != 0 { FdFlags::CLOEXEC } else { FdFlags::empty() };
370    let accepted_fd = current_task.add_file(locked, accepted_socket_file, fd_flags)?;
371    Ok(accepted_fd)
372}
373
/// Connects the socket referred to by `fd` to the address read from user memory.
///
/// The address is turned into a `SocketPeer`: unix names are resolved to a socket handle,
/// while inet, netlink, packet and qipcrtr addresses are passed through as addresses. An
/// unspecified address fails with `EAFNOSUPPORT`, an empty unix name with `ECONNREFUSED`, and
/// vsock addresses with `ENOSYS`.
///
/// For non-blocking files the result of the first connect attempt is returned as-is. For
/// blocking files, an `EINPROGRESS` result causes the call to wait for `POLLOUT` and then
/// call `connect` a second time, returning that result.
pub fn sys_connect(
    locked: &mut Locked<Unlocked>,
    current_task: &CurrentTask,
    fd: FdNumber,
    user_socket_address: UserAddress,
    user_address_length: usize,
) -> Result<(), Errno> {
    let client = current_task.files.get(fd)?;
    let client = SocketFile::get_from_file(&client)?;
    let address = parse_socket_address(current_task, user_socket_address, user_address_length)?;
    let peer = match address {
        SocketAddress::Unspecified => return error!(EAFNOSUPPORT),
        SocketAddress::Unix(ref name) => {
            log_trace!("connect to unix socket named \"{name}\"");
            if name.is_empty() {
                return error!(ECONNREFUSED);
            }
            SocketPeer::Handle(resolve_unix_socket_address(locked, current_task, name.as_ref())?)
        }
        // TODO(https://fxbug.dev/445433238): Connect not available for AF_VSOCK
        SocketAddress::Vsock { .. } => return error!(ENOSYS),
        SocketAddress::Inet(ref addr) | SocketAddress::Inet6(ref addr) => {
            log_trace!("connect to inet socket named {:?}", addr);
            SocketPeer::Address(address)
        }
        SocketAddress::Netlink(_) => SocketPeer::Address(address),
        SocketAddress::Packet(ref addr) => {
            log_trace!("connect to packet socket named {:?}", addr);
            SocketPeer::Address(address)
        }
        SocketAddress::Qipcrtr(ref addr) => {
            log_trace!("connect to qipcrtr socket named {:?}", addr);
            SocketPeer::Address(address)
        }
    };
    // `peer` is cloned because it may be needed again for the second attempt below.
    let result = client.connect(locked, current_task, peer.clone());

    // Non-blocking callers get the first result directly, including EINPROGRESS.
    if client.file().is_non_blocking() {
        return result;
    }

    match result {
        // EINPROGRESS may be returned for inet sockets when `connect()` is completed
        // asynchronously.
        Err(errno) if errno.code == EINPROGRESS => {
            let waiter = Waiter::new();
            client.file().wait_async(
                locked,
                current_task,
                &waiter,
                FdEvents::POLLOUT,
                WaitCallback::none(),
            );
            // Register the waiter first, then only block if POLLOUT is not already asserted.
            if !client.file().query_events(locked, current_task)?.contains(FdEvents::POLLOUT) {
                waiter.wait(locked, current_task)?;
            }
            client.connect(locked, current_task, peer)
        }
        // TODO(tbodt): Support blocking when the UNIX domain socket queue fills up. This one's
        // weird because as far as I can tell, removing a socket from the queue does not actually
        // trigger FdEvents on anything.
        result => result,
    }
}
438
439fn write_socket_address(
440    current_task: &CurrentTask,
441    user_socket_address: UserAddress,
442    user_address_length: UserRef<socklen_t>,
443    address_bytes: &[u8],
444) -> Result<(), Errno> {
445    let capacity = current_task.read_object(user_address_length)?;
446    if capacity > i32::MAX as socklen_t {
447        return error!(EINVAL);
448    }
449    let length = address_bytes.len() as socklen_t;
450    if length > 0 {
451        let actual = std::cmp::min(length, capacity) as usize;
452        current_task.write_memory(user_socket_address, &address_bytes[..actual])?;
453    }
454    current_task.write_object(user_address_length, &length)?;
455    Ok(())
456}
457
458pub fn sys_getsockname(
459    locked: &mut Locked<Unlocked>,
460    current_task: &CurrentTask,
461    fd: FdNumber,
462    user_socket_address: UserAddress,
463    user_address_length: UserRef<socklen_t>,
464) -> Result<(), Errno> {
465    let file = current_task.files.get(fd)?;
466    let socket = Socket::get_from_file(&file)?;
467    security::check_socket_getsockname_access(current_task, socket)?;
468    let address_bytes = socket.getsockname(locked)?.to_bytes();
469
470    write_socket_address(current_task, user_socket_address, user_address_length, &address_bytes)?;
471
472    Ok(())
473}
474
475pub fn sys_getpeername(
476    locked: &mut Locked<Unlocked>,
477    current_task: &CurrentTask,
478    fd: FdNumber,
479    user_socket_address: UserAddress,
480    user_address_length: UserRef<socklen_t>,
481) -> Result<(), Errno> {
482    let file = current_task.files.get(fd)?;
483    let socket = Socket::get_from_file(&file)?;
484    security::check_socket_getpeername_access(current_task, socket)?;
485    let address_bytes = socket.getpeername(locked)?.to_bytes();
486
487    write_socket_address(current_task, user_socket_address, user_address_length, &address_bytes)?;
488
489    Ok(())
490}
491
492pub fn sys_socketpair(
493    locked: &mut Locked<Unlocked>,
494    current_task: &CurrentTask,
495    domain: u32,
496    socket_type: u32,
497    protocol: u32,
498    user_sockets: UserRef<[FdNumber; 2]>,
499) -> Result<(), Errno> {
500    let flags = socket_type & (SOCK_NONBLOCK | SOCK_CLOEXEC);
501    let domain = parse_socket_domain(domain)?;
502    if !matches!(domain, SocketDomain::Unix | SocketDomain::Inet) {
503        return error!(EAFNOSUPPORT);
504    }
505    let socket_type = parse_socket_type(domain, socket_type)?;
506    let _protocol = parse_socket_protocol(domain, socket_type, protocol)?;
507    if domain != SocketDomain::Unix {
508        return error!(EOPNOTSUPP);
509    }
510    let open_flags = socket_flags_to_open_flags(flags);
511
512    let (left, right) =
513        UnixSocket::new_pair(locked, current_task, domain, socket_type, open_flags)?;
514
515    let fd_flags = socket_flags_to_fd_flags(flags);
516    // TODO: Eventually this will need to allocate two fd numbers (each of which could
517    // potentially fail), and only populate the fd numbers (which can't fail) if both allocations
518    // succeed.
519    let left_fd = current_task.add_file(locked, left, fd_flags)?;
520    let right_fd = current_task.add_file(locked, right, fd_flags)?;
521
522    let fds = [left_fd, right_fd];
523    log_trace!("socketpair -> [{:#x}, {:#x}]", fds[0].raw(), fds[1].raw());
524    current_task.write_object(user_sockets, &fds)?;
525
526    Ok(())
527}
528
529fn read_iovec_from_msghdr(
530    current_task: &CurrentTask,
531    message_header: WithAlternateBuffer<&MsgHdr>,
532) -> Result<UserBuffers, Errno> {
533    if let WithAlternateBuffer::WithAux(_, b) = message_header {
534        return Ok(UserBuffers::from_buf([b]));
535    }
536    let iovec_count = message_header.iovlen;
537
538    // In `CurrentTask::read_iovec()` the same check fails with `EINVAL`. This works for all
539    // syscalls that use `iovec`, except `sendmsg()` and `recvmsg()`, which need to fail with
540    // EMSGSIZE.
541    if iovec_count.raw() > UIO_MAXIOV as usize {
542        return error!(EMSGSIZE);
543    }
544
545    current_task.read_iovec(message_header.iov, iovec_count)
546}
547
/// Receives one message using the header described by `user_message_header`.
///
/// The header is first materialized as a local copy: read from user memory for
/// `MsgHdrRef::Ptr`, or cloned for `MsgHdrRef::Value`. After a successful receive the updated
/// header is stored back: written to user memory for `Ptr`, or assigned into the caller's
/// value for `Value`. If the receive fails, nothing is written back.
///
/// Returns the byte count reported by `recvmsg_internal_with_header`.
fn recvmsg_internal<L>(
    locked: &mut Locked<L>,
    current_task: &CurrentTask,
    file: &FileHandle,
    user_message_header: &mut MsgHdrRef,
    flags: u32,
    deadline: Option<zx::MonotonicInstant>,
) -> Result<usize, Errno>
where
    L: LockEqualOrBefore<FileOpsCore>,
{
    let mut message_header = match *user_message_header {
        MsgHdrRef::Ptr(ptr) => current_task.read_multi_arch_object(ptr)?.into(),
        MsgHdrRef::Value(ref value) => value.clone(),
    };
    let result = recvmsg_internal_with_header(
        locked,
        current_task,
        file,
        message_header.as_mut(),
        flags,
        deadline,
    )?;
    // Propagate the header fields updated by the receive back to where the header came from.
    match *user_message_header {
        MsgHdrRef::Ptr(ptr) => {
            current_task.write_multi_arch_object(ptr, message_header.extract())?;
        }
        MsgHdrRef::Value(ref mut value) => {
            *value.deref_mut() = message_header.extract();
        }
    }
    Ok(result)
}
581
582fn recvmsg_internal_with_header<L>(
583    locked: &mut Locked<L>,
584    current_task: &CurrentTask,
585    file: &FileHandle,
586    mut message_header: WithAlternateBuffer<&mut MsgHdr>,
587    flags: u32,
588    deadline: Option<zx::MonotonicInstant>,
589) -> Result<usize, Errno>
590where
591    L: LockEqualOrBefore<FileOpsCore>,
592{
593    let iovec = read_iovec_from_msghdr(current_task, message_header.as_unmut())?;
594
595    let flags = SocketMessageFlags::from_bits(flags).ok_or_else(|| errno!(EINVAL))?;
596    let socket_ops = file.downcast_file::<SocketFile>().unwrap();
597    let info = socket_ops.recvmsg(
598        locked,
599        current_task,
600        file,
601        &mut UserBuffersOutputBuffer::unified_new(current_task, iovec)?,
602        flags,
603        deadline,
604    )?;
605
606    message_header.flags = 0;
607
608    let cmsg_buffer_size = message_header.control_len;
609
610    let mut cmsg_bytes_written = 0;
611    let header_size = CMsgHdrPtr::size_of_object_for(current_task);
612
613    for ancillary_data in info.ancillary_data {
614        if ancillary_data.total_size(current_task) == 0 {
615            // Skip zero-byte ancillary data on the receiving end. Not doing this trips this
616            // assert:
617            // https://cs.android.com/android/platform/superproject/+/master:system/libbase/cmsg.cpp;l=144;drc=15ec2c7a23cda814351a064a345a8270ed8c83ab
618            continue;
619        }
620
621        let expected_size = header_size + ancillary_data.total_size(current_task);
622        let message_bytes = ancillary_data.into_bytes(
623            locked,
624            current_task,
625            flags,
626            cmsg_buffer_size - cmsg_bytes_written,
627        )?;
628
629        // If the message is smaller than expected, set the MSG_CTRUNC flag, so the caller can tell
630        // some of the message is missing.
631        let truncated = message_bytes.len() < expected_size;
632        if truncated {
633            message_header.flags |= MSG_CTRUNC;
634        }
635
636        if message_bytes.len() < header_size {
637            // Can't fit the header, so stop trying to write.
638            break;
639        }
640
641        if !message_bytes.is_empty() {
642            current_task
643                .write_memory((message_header.control + cmsg_bytes_written)?, &message_bytes)?;
644            cmsg_bytes_written += message_bytes.len();
645            if !truncated {
646                cmsg_bytes_written = cmsg_align(current_task, cmsg_bytes_written)?;
647            }
648        }
649    }
650
651    message_header.control_len = cmsg_bytes_written;
652
653    let msg_name = message_header.name;
654    if !msg_name.is_null() {
655        if message_header.name_len > i32::MAX as u32 {
656            return error!(EINVAL);
657        }
658        let bytes = info.address.map(|a| a.to_bytes()).unwrap_or_else(|| vec![]);
659        let num_bytes = std::cmp::min(message_header.name_len as usize, bytes.len());
660        message_header.name_len = bytes.len() as u32;
661        if num_bytes > 0 {
662            current_task.write_memory(msg_name, &bytes[..num_bytes])?;
663        }
664    }
665
666    if info.bytes_read != info.message_length {
667        message_header.flags |= MSG_TRUNC;
668    }
669
670    if flags.contains(SocketMessageFlags::TRUNC) {
671        Ok(info.message_length)
672    } else {
673        Ok(info.bytes_read)
674    }
675}
676
677pub fn sys_recvmsg(
678    locked: &mut Locked<Unlocked>,
679    current_task: &CurrentTask,
680    fd: FdNumber,
681    user_message_header: MsgHdrPtr,
682    flags: u32,
683) -> Result<usize, Errno> {
684    recvmsg_impl(locked, current_task, fd, &mut user_message_header.into(), flags)
685}
686
687/// Implementation of `recvmsg`.
688///
689/// This function is used by `sys_recvmsg`, but can also be called from other parts of the kernel
690/// that need to override the `iovec` from the `msghdr`. For example, when using `io_uring` with
691/// ring buffers.
692pub fn recvmsg_impl(
693    locked: &mut Locked<Unlocked>,
694    current_task: &CurrentTask,
695    fd: FdNumber,
696    user_message_header: &mut MsgHdrRef,
697    flags: u32,
698) -> Result<usize, Errno> {
699    let file = current_task.files.get(fd)?;
700    if !file.node().is_sock() {
701        return error!(ENOTSOCK);
702    }
703    recvmsg_internal(locked, current_task, &file, user_message_header, flags, None)
704}
705
706pub fn sys_recvmmsg(
707    locked: &mut Locked<Unlocked>,
708    current_task: &CurrentTask,
709    fd: FdNumber,
710    user_mmsgvec: MMsgHdrPtr,
711    vlen: u32,
712    mut flags: u32,
713    user_timeout: TimeSpecPtr,
714) -> Result<usize, Errno> {
715    let file = current_task.files.get(fd)?;
716    if !file.node().is_sock() {
717        return error!(ENOTSOCK);
718    }
719
720    if vlen > UIO_MAXIOV {
721        return error!(EINVAL);
722    }
723
724    let deadline = if user_timeout.is_null() {
725        None
726    } else {
727        let ts = current_task.read_multi_arch_object(user_timeout)?;
728        Some(zx::MonotonicInstant::after(duration_from_timespec(ts)?))
729    };
730
731    let mut index = 0usize;
732    while index < vlen as usize {
733        let current_ptr = user_mmsgvec.at(index)?;
734        let mut current_mmsghdr = current_task.read_multi_arch_object(current_ptr)?;
735        match recvmsg_internal_with_header(
736            locked,
737            current_task,
738            &file,
739            (&mut current_mmsghdr.hdr).into(),
740            flags,
741            deadline,
742        ) {
743            Err(error) => {
744                if index == 0 {
745                    return Err(error);
746                }
747                break;
748            }
749            Ok(bytes_read) => {
750                current_mmsghdr.len = bytes_read;
751                current_task.write_multi_arch_object(current_ptr, current_mmsghdr)?;
752            }
753        }
754        index += 1;
755        if flags & MSG_WAITFORONE != 0 {
756            flags |= MSG_DONTWAIT;
757        }
758    }
759    Ok(index)
760}
761
762pub fn sys_recvfrom(
763    locked: &mut Locked<Unlocked>,
764    current_task: &CurrentTask,
765    fd: FdNumber,
766    user_buffer: UserAddress,
767    buffer_length: usize,
768    flags: u32,
769    user_src_address: UserAddress,
770    user_src_address_length: UserRef<socklen_t>,
771) -> Result<usize, Errno> {
772    let file = current_task.files.get(fd)?;
773    if !file.node().is_sock() {
774        return error!(ENOTSOCK);
775    }
776
777    let flags = SocketMessageFlags::from_bits(flags).ok_or_else(|| errno!(EINVAL))?;
778    let socket_ops = file.downcast_file::<SocketFile>().unwrap();
779    let info = socket_ops.recvmsg(
780        locked,
781        current_task,
782        &file,
783        &mut UserBuffersOutputBuffer::unified_new_at(current_task, user_buffer, buffer_length)?,
784        flags,
785        None,
786    )?;
787
788    if !user_src_address.is_null() {
789        let bytes = info.address.map(|a| a.to_bytes()).unwrap_or_else(|| vec![]);
790        write_socket_address(current_task, user_src_address, user_src_address_length, &bytes)?;
791    }
792
793    if flags.contains(SocketMessageFlags::TRUNC) {
794        Ok(info.message_length)
795    } else {
796        Ok(info.bytes_read)
797    }
798}
799
800fn sendmsg_internal<L>(
801    locked: &mut Locked<L>,
802    current_task: &CurrentTask,
803    file: &FileHandle,
804    user_message_header: MsgHdrPtr,
805    flags: u32,
806) -> Result<usize, Errno>
807where
808    L: LockEqualOrBefore<FileOpsCore>,
809{
810    let message_header = current_task.read_multi_arch_object(user_message_header)?;
811    sendmsg_internal_with_header(locked, current_task, file, &message_header, flags)
812}
813
814fn sendmsg_internal_with_header<L>(
815    locked: &mut Locked<L>,
816    current_task: &CurrentTask,
817    file: &FileHandle,
818    message_header: &MsgHdr,
819    flags: u32,
820) -> Result<usize, Errno>
821where
822    L: LockEqualOrBefore<FileOpsCore>,
823{
824    if message_header.name_len > i32::MAX as u32 {
825        return error!(EINVAL);
826    }
827    if message_header.control_len > 20480 {
828        return error!(ENOBUFS);
829    }
830    let dest_address = maybe_parse_socket_address(
831        current_task,
832        message_header.name,
833        message_header.name_len as usize,
834    )?;
835    let iovec = read_iovec_from_msghdr(current_task, message_header.into())?;
836
837    let mut next_message_offset: usize = 0;
838    let mut ancillary_data = Vec::new();
839    let header_size = CMsgHdrPtr::size_of_object_for(current_task);
840    loop {
841        let space = message_header.control_len.saturating_sub(next_message_offset);
842        if space < header_size {
843            break;
844        }
845        let cmsg_ref =
846            CMsgHdrPtr::new(current_task, (message_header.control + next_message_offset)?);
847        let cmsg = current_task.read_multi_arch_object(cmsg_ref)?;
848        // If the message header is not long enough to fit the required fields of the
849        // control data, return EINVAL.
850        if (cmsg.cmsg_len as usize) < header_size {
851            return error!(EINVAL);
852        }
853
854        let data_size = std::cmp::min(cmsg.cmsg_len as usize - header_size, space);
855        let next_data_offset = next_message_offset + header_size;
856        let data = current_task
857            .read_memory_to_vec((message_header.control + next_data_offset)?, data_size)?;
858        next_message_offset += cmsg_align(current_task, header_size + data.len())?;
859        let data = AncillaryData::from_cmsg(
860            current_task,
861            ControlMsg::new(cmsg.cmsg_level, cmsg.cmsg_type, data),
862        )?;
863        if data.total_size(current_task) == 0 {
864            continue;
865        }
866        ancillary_data.push(data);
867    }
868
869    let flags = SocketMessageFlags::from_bits(flags).ok_or_else(|| errno!(EOPNOTSUPP))?;
870    let socket_ops = file.downcast_file::<SocketFile>().unwrap();
871    socket_ops.sendmsg(
872        locked,
873        current_task,
874        file,
875        &mut UserBuffersInputBuffer::unified_new(current_task, iovec)?,
876        dest_address,
877        ancillary_data,
878        flags,
879    )
880}
881
882pub fn sys_sendmsg(
883    locked: &mut Locked<Unlocked>,
884    current_task: &CurrentTask,
885    fd: FdNumber,
886    user_message_header: MsgHdrPtr,
887    flags: u32,
888) -> Result<usize, Errno> {
889    let file = current_task.files.get(fd)?;
890    if !file.node().is_sock() {
891        return error!(ENOTSOCK);
892    }
893    sendmsg_internal(locked, current_task, &file, user_message_header, flags)
894}
895
896pub fn sys_sendmmsg(
897    locked: &mut Locked<Unlocked>,
898    current_task: &CurrentTask,
899    fd: FdNumber,
900    user_mmsgvec: MMsgHdrPtr,
901    mut vlen: u32,
902    flags: u32,
903) -> Result<usize, Errno> {
904    let file = current_task.files.get(fd)?;
905    if !file.node().is_sock() {
906        return error!(ENOTSOCK);
907    }
908
909    // vlen is capped at UIO_MAXIOV.
910    if vlen > UIO_MAXIOV {
911        vlen = UIO_MAXIOV;
912    }
913
914    let mut index = 0usize;
915    while index < vlen as usize {
916        let current_ptr = user_mmsgvec.at(index)?;
917        let mut current_mmsghdr = current_task.read_multi_arch_object(current_ptr)?;
918        match sendmsg_internal_with_header(locked, current_task, &file, &current_mmsghdr.hdr, flags)
919        {
920            Err(error) => {
921                if index == 0 {
922                    return Err(error);
923                }
924                break;
925            }
926            Ok(bytes_read) => {
927                current_mmsghdr.len = bytes_read;
928                current_task.write_multi_arch_object(current_ptr, current_mmsghdr)?;
929            }
930        }
931        index += 1;
932    }
933    Ok(index)
934}
935
936pub fn sys_sendto(
937    locked: &mut Locked<Unlocked>,
938    current_task: &CurrentTask,
939    fd: FdNumber,
940    user_buffer: UserAddress,
941    user_buffer_length: usize,
942    flags: u32,
943    user_dest_address: UserAddress,
944    user_dest_address_length: socklen_t,
945) -> Result<usize, Errno> {
946    let file = current_task.files.get(fd)?;
947    if !file.node().is_sock() {
948        return error!(ENOTSOCK);
949    }
950
951    let dest_address = maybe_parse_socket_address(
952        current_task,
953        user_dest_address,
954        user_dest_address_length as usize,
955    )?;
956    let mut data =
957        UserBuffersInputBuffer::unified_new_at(current_task, user_buffer, user_buffer_length)?;
958
959    let flags = SocketMessageFlags::from_bits(flags).ok_or_else(|| errno!(EOPNOTSUPP))?;
960    let socket_file = file.downcast_file::<SocketFile>().unwrap();
961    socket_file.sendmsg(locked, current_task, &file, &mut data, dest_address, vec![], flags)
962}
963
964pub fn sys_getsockopt(
965    locked: &mut Locked<Unlocked>,
966    current_task: &CurrentTask,
967    fd: FdNumber,
968    level: u32,
969    optname: u32,
970    user_optval: UserAddress,
971    user_optlen: UserRef<socklen_t>,
972) -> Result<(), Errno> {
973    let file = current_task.files.get(fd)?;
974    let socket = Socket::get_from_file(&file)?;
975
976    let optlen = current_task.read_object(user_optlen)? as usize;
977    let optval_buffer_len = optlen;
978    let mut optval = current_task.read_memory_to_vec(user_optval, optlen as usize)?;
979
980    let result = if socket.domain.is_inet() && IpTables::can_handle_getsockopt(level, optname) {
981        current_task.kernel().iptables().getsockopt(
982            locked,
983            current_task,
984            socket,
985            optname,
986            optval.clone(),
987        )
988    } else {
989        socket.getsockopt(locked, current_task, level, optname, optlen as u32)
990    };
991
992    // Even if `getsockopt()` above returned an error we still need to run
993    // the eBPF program - it may handle the error.
994    let (optlen, error) = match result {
995        Ok(new_optval) if new_optval.len() > optval.len() => (optlen, Some(errno!(EINVAL))),
996        Ok(new_optval) => {
997            // Copy the returned value to the buffer, but don't truncate it yet
998            // - this will allow to use the whole buffer in the eBPF program.
999            optval[..new_optval.len()].copy_from_slice(&new_optval);
1000            (new_optval.len(), None)
1001        }
1002        Err(e) => (optlen, Some(e)),
1003    };
1004
1005    let root_cgroup = current_task.kernel().ebpf_state.attachments.root_cgroup();
1006    let (optval, optlen) = root_cgroup.run_getsockopt_prog(
1007        locked.cast_locked(),
1008        current_task,
1009        level,
1010        optname,
1011        optval,
1012        optlen,
1013        error,
1014    )?;
1015
1016    assert!(optlen <= optval_buffer_len);
1017    current_task.write_memory(user_optval, &optval[..optlen])?;
1018    current_task.write_object(user_optlen, &(optlen as u32))?;
1019
1020    Ok(())
1021}
1022
1023pub fn sys_setsockopt(
1024    locked: &mut Locked<Unlocked>,
1025    current_task: &CurrentTask,
1026    fd: FdNumber,
1027    level: u32,
1028    optname: u32,
1029    user_optval: UserAddress,
1030    optlen: socklen_t,
1031) -> Result<(), Errno> {
1032    let file = current_task.files.get(fd)?;
1033    let socket = Socket::get_from_file(&file)?;
1034
1035    let user_opt = UserBuffer { address: user_optval, length: optlen as usize };
1036
1037    // Run eBPF program if any.
1038    let root_cgroup = current_task.kernel().ebpf_state.attachments.root_cgroup();
1039    let optval = match root_cgroup.run_setsockopt_prog(
1040        locked.cast_locked(),
1041        current_task,
1042        level,
1043        optname,
1044        user_opt.into(),
1045    ) {
1046        SetSockOptProgramResult::Allow(value) => value,
1047        SetSockOptProgramResult::Fail(errno) => return Err(errno),
1048        SetSockOptProgramResult::Bypass => return Ok(()), // The option was handled by eBPF.
1049    };
1050
1051    if socket.domain.is_inet() && IpTables::can_handle_setsockopt(level, optname) {
1052        current_task.kernel().iptables().setsockopt(locked, current_task, socket, optname, optval)
1053    } else {
1054        socket.setsockopt(locked, current_task, level, optname, optval)
1055    }
1056}
1057
1058pub fn sys_shutdown(
1059    locked: &mut Locked<Unlocked>,
1060    current_task: &CurrentTask,
1061    fd: FdNumber,
1062    how: u32,
1063) -> Result<(), Errno> {
1064    let file = current_task.files.get(fd)?;
1065    let socket = Socket::get_from_file(&file)?;
1066    let how = match how {
1067        SHUT_RD => SocketShutdownFlags::READ,
1068        SHUT_WR => SocketShutdownFlags::WRITE,
1069        SHUT_RDWR => SocketShutdownFlags::READ | SocketShutdownFlags::WRITE,
1070        _ => return error!(EINVAL),
1071    };
1072    socket.shutdown(locked, current_task, how)?;
1073    Ok(())
1074}
1075
1076pub fn cmsg_align(current_task: &CurrentTask, value: usize) -> Result<usize, Errno> {
1077    let alignment = if current_task.is_arch32() { 4 } else { 8 };
1078    round_up_to_increment(value, alignment)
1079}
1080
// Syscalls for arch32 usage
#[cfg(target_arch = "aarch64")]
mod arch32 {
    use crate::task::CurrentTask;
    use crate::vfs::FdNumber;
    use starnix_sync::{Locked, Unlocked};
    use starnix_uapi::errors::Errno;
    use starnix_uapi::user_address::UserAddress;

    // These syscalls are exposed to arch32 callers under `sys_arch32_*` names by
    // re-exporting the regular implementations.
    pub use super::{
        sys_accept as sys_arch32_accept, sys_accept4 as sys_arch32_accept4,
        sys_bind as sys_arch32_bind, sys_getpeername as sys_arch32_getpeername,
        sys_getsockname as sys_arch32_getsockname, sys_getsockopt as sys_arch32_getsockopt,
        sys_listen as sys_arch32_listen, sys_recvfrom as sys_arch32_recvfrom,
        sys_recvmmsg as sys_arch32_recvmmsg, sys_recvmsg as sys_arch32_recvmsg,
        sys_sendmsg as sys_arch32_sendmsg, sys_sendto as sys_arch32_sendto,
        sys_setsockopt as sys_arch32_setsockopt, sys_shutdown as sys_arch32_shutdown,
        sys_socketpair as sys_arch32_socketpair,
    };

    /// arch32 `send`: forwards to `sys_sendto` with a default-constructed
    /// destination address and a destination length of zero.
    pub fn sys_arch32_send(
        locked: &mut Locked<Unlocked>,
        current_task: &CurrentTask,
        fd: FdNumber,
        user_buffer: UserAddress,
        user_buffer_length: usize,
        flags: u32,
    ) -> Result<usize, Errno> {
        super::sys_sendto(
            locked,
            current_task,
            fd,
            user_buffer,
            user_buffer_length,
            flags,
            UserAddress::default(),
            0,
        )
    }

    /// arch32 `recv`: forwards to `sys_recvfrom` with default values for its two
    /// trailing arguments.
    pub fn sys_arch32_recv(
        locked: &mut Locked<Unlocked>,
        current_task: &CurrentTask,
        fd: FdNumber,
        user_buffer: UserAddress,
        buffer_length: usize,
        flags: u32,
    ) -> Result<usize, Errno> {
        super::sys_recvfrom(
            locked,
            current_task,
            fd,
            user_buffer,
            buffer_length,
            flags,
            Default::default(),
            Default::default(),
        )
    }
}
1141
1142#[cfg(target_arch = "aarch64")]
1143pub use arch32::*;
1144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::spawn_kernel_and_run;
    use starnix_uapi::{AF_INET, AF_UNIX, SOCK_STREAM};

    /// Checks the error returned by `sys_socketpair` for three invalid calls, each
    /// made with protocol 0 and a default-constructed output pointer.
    #[::fuchsia::test]
    async fn test_socketpair_invalid_arguments() {
        spawn_kernel_and_run(async |locked, current_task| {
            // (domain, socket type, expected result)
            let cases = [
                (AF_INET as u32, SOCK_STREAM, error!(EPROTONOSUPPORT)),
                (AF_UNIX as u32, 7, error!(EINVAL)),
                (AF_UNIX as u32, SOCK_STREAM, error!(EFAULT)),
            ];
            for (domain, socket_type, expected) in cases {
                assert_eq!(
                    sys_socketpair(
                        locked,
                        current_task,
                        domain,
                        socket_type,
                        0,
                        UserRef::new(UserAddress::default())
                    ),
                    expected
                );
            }
        })
        .await;
    }

    /// Checks that a generated autobind address is six bytes long, starts with a
    /// zero byte, and otherwise contains only lowercase hexadecimal digits.
    #[::fuchsia::test]
    fn test_generate_autobind_address() {
        let address = generate_autobind_address();
        assert_eq!(address.len(), 6);
        assert_eq!(address[0], 0);

        let first_bad =
            address[1..].iter().find(|byte| !matches!(**byte, b'0'..=b'9' | b'a'..=b'f'));
        if let Some(bad) = first_bad {
            panic!("bad byte: {bad}");
        }
    }
}