Skip to main content

starnix_core/vfs/socket/
socket_netlink.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::security::{self, AuditLogger, AuditMessage, AuditRequest};
6use crate::vfs::socket::{SockOptValue, SocketDomain};
7use futures::channel::mpsc::{
8    UnboundedReceiver, UnboundedSender, {self},
9};
10use linux_uapi::{AUDIT_GET, NETLINK_GET_STRICT_CHK, audit_status};
11use netlink::messaging::{
12    AccessControl, MessageWithPermission, NetlinkContext, NetlinkMessageWithCreds, Permission,
13    Sender, UnparsedNetlinkMessage,
14};
15use netlink::multicast_groups::{
16    InvalidLegacyGroupsError, InvalidModernGroupError, LegacyGroups, ModernGroup,
17    NoMappingFromModernToLegacyGroupError, SingleLegacyGroup,
18};
19use netlink::protocol_family::NetlinkClient;
20use netlink::protocol_family::route::NetlinkRouteClient;
21use netlink::protocol_family::sock_diag::NetlinkSockDiagClient;
22use netlink::{NETLINK_LOG_TAG, NewClientError};
23use netlink_packet_core::{
24    ErrorMessage, NETLINK_HEADER_LEN, NLMSG_ERROR, NetlinkBuffer, NetlinkDeserializable,
25    NetlinkHeader, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
26};
27use netlink_packet_generic::message::EmptyDeserializeOptions as EmptyDeserializeGenlOptions;
28use netlink_packet_route::{RouteNetlinkMessage, RouteNetlinkMessageParseMode};
29use netlink_packet_sock_diag::SockDiagRequest;
30use netlink_packet_sock_diag::message::EmptyDeserializeOptions as EmptyDeserializeSockDiagOptions;
31use netlink_packet_utils::{DecodeError, Emitable as _};
32use starnix_sync::{FileOpsCore, LockEqualOrBefore, Locked, Mutex};
33use std::marker::PhantomData;
34use std::num::{NonZeroI32, NonZeroU32};
35use std::sync::Arc;
36use zerocopy::{FromBytes, IntoBytes};
37
38use crate::device::kobject::{Device, UEventAction, UEventContext};
39use crate::device::{DeviceListener, DeviceListenerKey};
40use crate::task::{CurrentTask, EventHandler, Kernel, WaitCanceler, WaitQueue, Waiter};
41use crate::vfs::buffers::{
42    AncillaryData, InputBuffer, Message, MessageQueue, MessageReadInfo, OutputBuffer,
43    UnixControlData, VecInputBuffer,
44};
45use crate::vfs::socket::{
46    GenericMessage, GenericNetlinkClientHandle, Socket, SocketAddress, SocketHandle,
47    SocketMessageFlags, SocketOps, SocketPeer, SocketShutdownFlags, SocketType,
48};
49use starnix_logging::{log_debug, log_error, log_warn, track_stub};
50use starnix_uapi::auth::{CAP_AUDIT_CONTROL, CAP_AUDIT_WRITE, CAP_NET_ADMIN, Credentials};
51use starnix_uapi::errors::Errno;
52use starnix_uapi::vfs::FdEvents;
53use starnix_uapi::{
54    AF_NETLINK, NETLINK_ADD_MEMBERSHIP, NETLINK_AUDIT, NETLINK_CONNECTOR, NETLINK_CRYPTO,
55    NETLINK_DNRTMSG, NETLINK_DROP_MEMBERSHIP, NETLINK_ECRYPTFS, NETLINK_FIB_LOOKUP,
56    NETLINK_FIREWALL, NETLINK_GENERIC, NETLINK_IP6_FW, NETLINK_ISCSI, NETLINK_KOBJECT_UEVENT,
57    NETLINK_NETFILTER, NETLINK_NFLOG, NETLINK_RDMA, NETLINK_ROUTE, NETLINK_SCSITRANSPORT,
58    NETLINK_SELINUX, NETLINK_SMC, NETLINK_SOCK_DIAG, NETLINK_USERSOCK, NETLINK_XFRM, NLM_F_MULTI,
59    NLMSG_DONE, SO_PASSCRED, SO_PROTOCOL, SO_RCVBUF, SO_RCVBUFFORCE, SO_SNDBUF, SO_SNDBUFFORCE,
60    SO_TIMESTAMP, SOL_SOCKET, errno, error, nlmsghdr, sockaddr_nl, socklen_t, ucred,
61};
62
// From netlink/socket.go in gVisor.
/// Smallest buffer size (4 KiB) that a clamped `SO_SNDBUF`/`SO_RCVBUF` request can yield.
pub const SOCKET_MIN_SIZE: usize = 4 << 10;
/// Buffer size (16 KiB) a newly created netlink socket starts with.
pub const SOCKET_DEFAULT_SIZE: usize = 16 * 1024;
/// Largest buffer size (4 MiB) that a clamped `SO_SNDBUF`/`SO_RCVBUF` request can yield.
pub const SOCKET_MAX_SIZE: usize = 4 << 20;

// From linux/socket.go in gVisor.
/// Socket option level selecting netlink-specific options in `getsockopt`/`setsockopt`.
const SOL_NETLINK: u32 = 270;
70
71pub fn new_netlink_socket(
72    kernel: &Arc<Kernel>,
73    socket_type: SocketType,
74    family: NetlinkFamily,
75) -> Result<Box<dyn SocketOps>, Errno> {
76    log_debug!(tag = NETLINK_LOG_TAG; "Creating {:?} Netlink Socket", family);
77    if socket_type != SocketType::Datagram && socket_type != SocketType::Raw {
78        return error!(ESOCKTNOSUPPORT);
79    }
80
81    let ops: Box<dyn SocketOps> = match family {
82        NetlinkFamily::KobjectUevent => Box::new(UEventNetlinkSocket::default()),
83        NetlinkFamily::Route => Box::new(new_route_socket(kernel)?),
84        NetlinkFamily::Generic => Box::new(GenericNetlinkSocket::new(kernel)?),
85        NetlinkFamily::SockDiag => Box::new(new_sock_diag_socket(kernel)?),
86        NetlinkFamily::Audit => Box::new(AuditNetlinkSocket::new(kernel)?),
87        NetlinkFamily::Usersock
88        | NetlinkFamily::Firewall
89        | NetlinkFamily::Nflog
90        | NetlinkFamily::Xfrm
91        | NetlinkFamily::Selinux
92        | NetlinkFamily::Iscsi
93        | NetlinkFamily::FibLookup
94        | NetlinkFamily::Connector
95        | NetlinkFamily::Netfilter
96        | NetlinkFamily::Ip6Fw
97        | NetlinkFamily::Dnrtmsg
98        | NetlinkFamily::Scsitransport
99        | NetlinkFamily::Ecryptfs
100        | NetlinkFamily::Rdma
101        | NetlinkFamily::Crypto
102        | NetlinkFamily::Smc => Box::new(StubbedNetlinkSocket::new(family)),
103        NetlinkFamily::Invalid => return error!(EINVAL),
104    };
105    Ok(ops)
106}
107
/// A netlink socket address: the port id and groups value carried by a `sockaddr_nl`.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[repr(C)]
pub struct NetlinkAddress {
    /// Port id of the endpoint; serialized as `nl_pid`.
    pid: u32,
    /// Multicast groups value; serialized as `nl_groups`.
    groups: u32,
}
114
115impl NetlinkAddress {
116    pub fn new(pid: u32, groups: u32) -> Self {
117        NetlinkAddress { pid, groups }
118    }
119
120    pub fn set_pid_if_zero(&mut self, pid: i32) {
121        if self.pid == 0 {
122            self.pid = pid as u32;
123        }
124    }
125
126    pub fn to_bytes(&self) -> Vec<u8> {
127        sockaddr_nl { nl_family: AF_NETLINK, nl_pid: self.pid, nl_pad: 0, nl_groups: self.groups }
128            .as_bytes()
129            .to_vec()
130    }
131}
132
/// The netlink protocol families a socket can be created for.
///
/// `Invalid` is what [`NetlinkFamily::from_raw`] produces for a protocol number it does not
/// recognize.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub enum NetlinkFamily {
    Invalid,
    Route,
    Usersock,
    Firewall,
    SockDiag,
    Nflog,
    Xfrm,
    Selinux,
    Iscsi,
    Audit,
    FibLookup,
    Connector,
    Netfilter,
    Ip6Fw,
    Dnrtmsg,
    KobjectUevent,
    Generic,
    Scsitransport,
    Ecryptfs,
    Rdma,
    Crypto,
    Smc,
}
158
159impl NetlinkFamily {
160    pub fn from_raw(family: u32) -> Self {
161        match family {
162            NETLINK_ROUTE => NetlinkFamily::Route,
163            NETLINK_USERSOCK => NetlinkFamily::Usersock,
164            NETLINK_FIREWALL => NetlinkFamily::Firewall,
165            NETLINK_SOCK_DIAG => NetlinkFamily::SockDiag,
166            NETLINK_NFLOG => NetlinkFamily::Nflog,
167            NETLINK_XFRM => NetlinkFamily::Xfrm,
168            NETLINK_SELINUX => NetlinkFamily::Selinux,
169            NETLINK_ISCSI => NetlinkFamily::Iscsi,
170            NETLINK_AUDIT => NetlinkFamily::Audit,
171            NETLINK_FIB_LOOKUP => NetlinkFamily::FibLookup,
172            NETLINK_CONNECTOR => NetlinkFamily::Connector,
173            NETLINK_NETFILTER => NetlinkFamily::Netfilter,
174            NETLINK_IP6_FW => NetlinkFamily::Ip6Fw,
175            NETLINK_DNRTMSG => NetlinkFamily::Dnrtmsg,
176            NETLINK_KOBJECT_UEVENT => NetlinkFamily::KobjectUevent,
177            NETLINK_GENERIC => NetlinkFamily::Generic,
178            NETLINK_SCSITRANSPORT => NetlinkFamily::Scsitransport,
179            NETLINK_ECRYPTFS => NetlinkFamily::Ecryptfs,
180            NETLINK_RDMA => NetlinkFamily::Rdma,
181            NETLINK_CRYPTO => NetlinkFamily::Crypto,
182            NETLINK_SMC => NetlinkFamily::Smc,
183            _ => NetlinkFamily::Invalid,
184        }
185    }
186
187    pub fn as_raw(&self) -> u32 {
188        match self {
189            NetlinkFamily::Route => NETLINK_ROUTE,
190            NetlinkFamily::KobjectUevent => NETLINK_KOBJECT_UEVENT,
191            NetlinkFamily::Audit => NETLINK_AUDIT,
192            _ => 0,
193        }
194    }
195}
196
/// Per-socket state shared by the netlink socket implementations in this file: the receive
/// queue, the bound address, the wait queue and the socket options.
struct NetlinkSocketInner {
    /// The specific type of netlink socket.
    family: NetlinkFamily,

    /// The [`MessageQueue`] that contains messages from netlink to the client.
    receive_buffer: MessageQueue,

    /// The socket's send buffer size. Note: this value is only used
    /// to serve getsockopt calls for `SO_SNDBUF`. It does not yet enforce a
    /// limit on the number of messages netlink will buffer from the client.
    /// TODO(https://fxbug.dev/285880057): Limit the size of the send buffer.
    send_buf_size: usize,

    /// This queue will be notified on reads, writes, disconnects etc.
    waiters: WaitQueue,

    /// The address of this socket, or `None` while the socket is unbound.
    address: Option<NetlinkAddress>,

    /// See SO_PASSCRED.
    pub passcred: bool,

    /// See SO_TIMESTAMP.
    pub timestamp: bool,

    /// See NETLINK_GET_STRICT_CHK.
    pub strict_chk: bool,
}
225
226impl NetlinkSocketInner {
227    fn new(family: NetlinkFamily) -> Self {
228        Self {
229            family,
230            receive_buffer: MessageQueue::new(SOCKET_DEFAULT_SIZE),
231            send_buf_size: SOCKET_DEFAULT_SIZE,
232            waiters: WaitQueue::default(),
233            address: None,
234            passcred: false,
235            timestamp: false,
236            strict_chk: false,
237        }
238    }
239
240    fn bind(
241        &mut self,
242        current_task: &CurrentTask,
243        socket_address: SocketAddress,
244    ) -> Result<(), Errno> {
245        if self.address.is_some() {
246            return error!(EINVAL);
247        }
248
249        let netlink_address = match socket_address {
250            SocketAddress::Netlink(mut netlink_address) => {
251                // TODO: Support distinct IDs for processes with multiple netlink sockets.
252                netlink_address.set_pid_if_zero(current_task.get_pid());
253                netlink_address
254            }
255            _ => return error!(EINVAL),
256        };
257
258        self.address = Some(netlink_address);
259        Ok(())
260    }
261
262    fn connect(&mut self, current_task: &CurrentTask, peer: SocketPeer) -> Result<(), Errno> {
263        let address = match peer {
264            SocketPeer::Address(address) => address,
265            _ => return error!(EINVAL),
266        };
267        // Connect is equivalent to bind, but error are ignored.
268        let _ = self.bind(current_task, address);
269        Ok(())
270    }
271
272    fn read_message(&mut self) -> Option<Message> {
273        let message = self.receive_buffer.read_message();
274        if message.is_some() {
275            self.waiters.notify_fd_events(FdEvents::POLLOUT);
276        }
277        message
278    }
279
280    fn read_datagram(
281        &mut self,
282        data: &mut dyn OutputBuffer,
283        flags: SocketMessageFlags,
284    ) -> Result<MessageReadInfo, Errno> {
285        let mut info = if flags.contains(SocketMessageFlags::PEEK) {
286            self.receive_buffer.peek_datagram(data)
287        } else {
288            self.receive_buffer.read_datagram(data)
289        }?;
290        if info.message_length == 0 {
291            return error!(EAGAIN);
292        }
293
294        if self.passcred {
295            track_stub!(TODO("https://fxbug.dev/297373991"), "SCM_CREDENTIALS/SO_PASSCRED");
296            info.ancillary_data.push(AncillaryData::Unix(UnixControlData::unknown_creds()));
297        }
298
299        Ok(info)
300    }
301
302    fn write_to_queue(
303        &mut self,
304        data: &mut dyn InputBuffer,
305        address: Option<NetlinkAddress>,
306        ancillary_data: &mut Vec<AncillaryData>,
307    ) -> Result<usize, Errno> {
308        let socket_address = match address {
309            Some(addr) => Some(SocketAddress::Netlink(addr)),
310            None => self.address.as_ref().map(|addr| SocketAddress::Netlink(addr.clone())),
311        };
312        let bytes_written =
313            self.receive_buffer.write_datagram(data, socket_address, ancillary_data)?;
314        if bytes_written > 0 {
315            self.waiters.notify_fd_events(FdEvents::POLLIN);
316        }
317        Ok(bytes_written)
318    }
319
320    fn wait_async(
321        &mut self,
322        waiter: &Waiter,
323        events: FdEvents,
324        handler: EventHandler,
325    ) -> WaitCanceler {
326        self.waiters.wait_async_fd_events(waiter, events, handler)
327    }
328
329    fn query_events(&self) -> FdEvents {
330        self.receive_buffer.query_events()
331    }
332
333    fn getsockname(&self) -> Result<SocketAddress, Errno> {
334        match &self.address {
335            Some(addr) => Ok(SocketAddress::Netlink(addr.clone())),
336            _ => Ok(SocketAddress::default_for_domain(SocketDomain::Netlink)),
337        }
338    }
339
340    fn getpeername(&self) -> Result<SocketAddress, Errno> {
341        match &self.address {
342            Some(addr) => Ok(SocketAddress::Netlink(addr.clone())),
343            _ => Ok(SocketAddress::default_for_domain(SocketDomain::Netlink)),
344        }
345    }
346
347    fn getsockopt(&self, level: u32, optname: u32) -> Result<Vec<u8>, Errno> {
348        let opt_value = match level {
349            SOL_SOCKET => match optname {
350                SO_PASSCRED => (self.passcred as u32).as_bytes().to_vec(),
351                SO_TIMESTAMP => (self.timestamp as u32).as_bytes().to_vec(),
352                SO_SNDBUF => (self.send_buf_size as socklen_t).to_ne_bytes().to_vec(),
353                SO_RCVBUF => (self.receive_buffer.capacity() as socklen_t).to_ne_bytes().to_vec(),
354                SO_SNDBUFFORCE => (self.send_buf_size as socklen_t).to_ne_bytes().to_vec(),
355                SO_RCVBUFFORCE => {
356                    (self.receive_buffer.capacity() as socklen_t).to_ne_bytes().to_vec()
357                }
358                SO_PROTOCOL => self.family.as_raw().as_bytes().to_vec(),
359                _ => return error!(ENOSYS),
360            },
361            SOL_NETLINK => match optname {
362                NETLINK_GET_STRICT_CHK => (self.strict_chk as u32).as_bytes().to_vec(),
363                _ => return error!(ENOSYS),
364            },
365            _ => vec![],
366        };
367
368        Ok(opt_value)
369    }
370
371    fn setsockopt(
372        &mut self,
373        current_task: &CurrentTask,
374        level: u32,
375        optname: u32,
376        optval: SockOptValue,
377    ) -> Result<(), Errno> {
378        match level {
379            SOL_SOCKET => match optname {
380                SO_SNDBUF => {
381                    let requested_capacity: socklen_t = optval.read(current_task)?;
382                    // SO_SNDBUF doubles the requested capacity to leave space for bookkeeping.
383                    // See https://man7.org/linux/man-pages/man7/socket.7.html
384                    let capacity = usize::try_from(requested_capacity * 2).unwrap_or(usize::MAX);
385                    // TODO(https://fxbug.dev/322907334): Clamp to `wmem_max`.
386                    let capacity = capacity.clamp(SOCKET_MIN_SIZE, SOCKET_MAX_SIZE);
387                    self.send_buf_size = capacity;
388                }
389                SO_SNDBUFFORCE => {
390                    security::check_task_capable(current_task, CAP_NET_ADMIN)?;
391                    let requested_capacity: socklen_t = optval.read(current_task)?;
392                    // SO_SNDBUFFORE doubles the requested capacity to leave space for bookkeeping.
393                    // See https://man7.org/linux/man-pages/man7/socket.7.html
394                    let capacity = usize::try_from(requested_capacity * 2).unwrap_or(usize::MAX);
395                    self.send_buf_size = capacity;
396                }
397                SO_RCVBUF => {
398                    let requested_capacity: socklen_t = optval.read(current_task)?;
399                    // SO_RCVBUF doubles the requested capacity to leave space for bookkeeping.
400                    // See https://man7.org/linux/man-pages/man7/socket.7.html
401                    let capacity = usize::try_from(requested_capacity * 2).unwrap_or(usize::MAX);
402                    // TODO(https://fxbug.dev/322906968): Clamp to `rmem_max`.
403                    let capacity = capacity.clamp(SOCKET_MIN_SIZE, SOCKET_MAX_SIZE);
404                    self.receive_buffer.set_capacity(capacity)?;
405                }
406                SO_RCVBUFFORCE => {
407                    security::check_task_capable(current_task, CAP_NET_ADMIN)?;
408                    let requested_capacity: socklen_t = optval.read(current_task)?;
409                    // SO_RCVBUFFORE doubles the requested capacity to leave space for bookkeeping.
410                    // See https://man7.org/linux/man-pages/man7/socket.7.html
411                    let capacity = usize::try_from(requested_capacity * 2).unwrap_or(usize::MAX);
412                    self.receive_buffer.set_capacity(capacity)?;
413                }
414                SO_PASSCRED => {
415                    let passcred: u32 = optval.read(current_task)?;
416                    self.passcred = passcred != 0;
417                }
418                SO_TIMESTAMP => {
419                    let timestamp: u32 = optval.read(current_task)?;
420                    self.timestamp = timestamp != 0;
421                }
422                _ => return error!(ENOSYS),
423            },
424            SOL_NETLINK => match optname {
425                NETLINK_GET_STRICT_CHK => {
426                    let strict_chk: u32 = optval.read(current_task)?;
427                    self.strict_chk = strict_chk != 0;
428                }
429                _ => return error!(ENOSYS),
430            },
431            _ => return error!(ENOSYS),
432        }
433
434        Ok(())
435    }
436}
437
/// A fake Netlink socket that loops messages back to the client.
///
/// Used as a placeholder implementation for protocol families that lack a real
/// implementation.
struct StubbedNetlinkSocket {
    /// The socket state (receive queue, address, options), guarded by a mutex.
    inner: Mutex<NetlinkSocketInner>,
}
445
446impl StubbedNetlinkSocket {
447    pub fn new(family: NetlinkFamily) -> Self {
448        track_stub!(
449            TODO("https://fxbug.dev/278565021"),
450            format!("Creating StubbedNetlinkSocket: {:?}", family).as_str()
451        );
452        StubbedNetlinkSocket { inner: Mutex::new(NetlinkSocketInner::new(family)) }
453    }
454
455    /// Locks and returns the inner state of the Socket.
456    fn lock(&self) -> starnix_sync::MutexGuard<'_, NetlinkSocketInner> {
457        self.inner.lock()
458    }
459}
460
461impl SocketOps for StubbedNetlinkSocket {
462    fn connect(
463        &self,
464        _locked: &mut Locked<FileOpsCore>,
465        _socket: &SocketHandle,
466        current_task: &CurrentTask,
467        peer: SocketPeer,
468    ) -> Result<(), Errno> {
469        self.lock().connect(current_task, peer)
470    }
471
472    fn listen(
473        &self,
474        _locked: &mut Locked<FileOpsCore>,
475        _socket: &Socket,
476        _backlog: i32,
477        _credentials: ucred,
478    ) -> Result<(), Errno> {
479        error!(EOPNOTSUPP)
480    }
481
482    fn accept(
483        &self,
484        _locked: &mut Locked<FileOpsCore>,
485        _socket: &Socket,
486        _current_task: &CurrentTask,
487    ) -> Result<SocketHandle, Errno> {
488        error!(EOPNOTSUPP)
489    }
490
491    fn bind(
492        &self,
493        _locked: &mut Locked<FileOpsCore>,
494        _socket: &Socket,
495        current_task: &CurrentTask,
496        socket_address: SocketAddress,
497    ) -> Result<(), Errno> {
498        self.lock().bind(current_task, socket_address)
499    }
500
501    fn read(
502        &self,
503        _locked: &mut Locked<FileOpsCore>,
504        _socket: &Socket,
505        _current_task: &CurrentTask,
506        data: &mut dyn OutputBuffer,
507        _flags: SocketMessageFlags,
508    ) -> Result<MessageReadInfo, Errno> {
509        let msg = self.lock().read_message();
510        match msg {
511            Some(message) => {
512                // Mark the message as complete and return it.
513                let (mut nl_msg, _) =
514                    nlmsghdr::read_from_prefix(&message.data).map_err(|_| errno!(EINVAL))?;
515                nl_msg.nlmsg_type = NLMSG_DONE as u16;
516                nl_msg.nlmsg_flags &= NLM_F_MULTI as u16;
517                let msg_bytes = nl_msg.as_bytes();
518                let bytes_read = data.write(msg_bytes)?;
519
520                let info = MessageReadInfo {
521                    bytes_read,
522                    message_length: msg_bytes.len(),
523                    address: Some(SocketAddress::Netlink(NetlinkAddress::default())),
524                    ancillary_data: vec![],
525                };
526                Ok(info)
527            }
528            None => Ok(MessageReadInfo::default()),
529        }
530    }
531
532    fn write(
533        &self,
534        _locked: &mut Locked<FileOpsCore>,
535        _socket: &Socket,
536        _current_task: &CurrentTask,
537        data: &mut dyn InputBuffer,
538        dest_address: &mut Option<SocketAddress>,
539        ancillary_data: &mut Vec<AncillaryData>,
540    ) -> Result<usize, Errno> {
541        let mut local_address = self.lock().address.clone();
542
543        let destination = match dest_address {
544            Some(SocketAddress::Netlink(addr)) => addr,
545            _ => match &mut local_address {
546                Some(addr) => addr,
547                _ => return Ok(data.drain()),
548            },
549        };
550
551        if destination.groups != 0 {
552            track_stub!(TODO("https://fxbug.dev/322874956"), "StubbedNetlinkSockets multicasting");
553            return Ok(data.drain());
554        }
555
556        self.lock().write_to_queue(data, Some(NetlinkAddress::default()), ancillary_data)
557    }
558
559    fn wait_async(
560        &self,
561        _locked: &mut Locked<FileOpsCore>,
562        _socket: &Socket,
563        _current_task: &CurrentTask,
564        waiter: &Waiter,
565        events: FdEvents,
566        handler: EventHandler,
567    ) -> WaitCanceler {
568        self.lock().wait_async(waiter, events, handler)
569    }
570
571    fn query_events(
572        &self,
573        _locked: &mut Locked<FileOpsCore>,
574        _socket: &Socket,
575        _current_task: &CurrentTask,
576    ) -> Result<FdEvents, Errno> {
577        Ok(self.lock().query_events() & FdEvents::POLLIN)
578    }
579
580    fn shutdown(
581        &self,
582        _locked: &mut Locked<FileOpsCore>,
583        _socket: &Socket,
584        _how: SocketShutdownFlags,
585    ) -> Result<(), Errno> {
586        track_stub!(TODO("https://fxbug.dev/322875507"), "StubbedNetlinkSocket::shutdown");
587        Ok(())
588    }
589
590    fn close(
591        &self,
592        _locked: &mut Locked<FileOpsCore>,
593        _current_task: &CurrentTask,
594        _socket: &Socket,
595    ) {
596    }
597
598    fn getsockname(
599        &self,
600        _locked: &mut Locked<FileOpsCore>,
601        _socket: &Socket,
602    ) -> Result<SocketAddress, Errno> {
603        self.lock().getsockname()
604    }
605
606    fn getpeername(
607        &self,
608        _locked: &mut Locked<FileOpsCore>,
609        _socket: &Socket,
610    ) -> Result<SocketAddress, Errno> {
611        self.lock().getpeername()
612    }
613
614    fn getsockopt(
615        &self,
616        _locked: &mut Locked<FileOpsCore>,
617        _socket: &Socket,
618        _current_task: &CurrentTask,
619        level: u32,
620        optname: u32,
621        _optlen: u32,
622    ) -> Result<Vec<u8>, Errno> {
623        self.lock().getsockopt(level, optname)
624    }
625
626    fn setsockopt(
627        &self,
628        _locked: &mut Locked<FileOpsCore>,
629        _socket: &Socket,
630        current_task: &CurrentTask,
631        level: u32,
632        optname: u32,
633        optval: SockOptValue,
634    ) -> Result<(), Errno> {
635        self.lock().setsockopt(current_task, level, optname, optval)
636    }
637}
638
/// Socket implementation for the NETLINK_KOBJECT_UEVENT family of netlink sockets.
struct UEventNetlinkSocket {
    /// The socket state. A clone of this `Arc` is what gets registered with the device
    /// registry as the device listener.
    inner: Arc<Mutex<NetlinkSocketInner>>,
    /// Key of the device listener registration; `None` while no listener is registered.
    device_listener_key: Mutex<Option<DeviceListenerKey>>,
}
644
645impl Default for UEventNetlinkSocket {
646    #[allow(clippy::let_and_return)]
647    fn default() -> Self {
648        let result = Self {
649            inner: Arc::new(Mutex::new(NetlinkSocketInner::new(NetlinkFamily::KobjectUevent))),
650            device_listener_key: Default::default(),
651        };
652        #[cfg(any(test, debug_assertions))]
653        {
654            let _l1 = result.device_listener_key.lock();
655            let _l2 = result.lock();
656        }
657        result
658    }
659}
660
661impl UEventNetlinkSocket {
662    /// Locks and returns the inner state of the Socket.
663    fn lock(&self) -> starnix_sync::MutexGuard<'_, NetlinkSocketInner> {
664        self.inner.lock()
665    }
666
667    fn register_listener<L>(
668        &self,
669        locked: &mut Locked<L>,
670        current_task: &CurrentTask,
671        state: starnix_sync::MutexGuard<'_, NetlinkSocketInner>,
672    ) where
673        L: LockEqualOrBefore<FileOpsCore>,
674    {
675        if state.address.is_none() {
676            return;
677        }
678        std::mem::drop(state);
679        let mut key_state = self.device_listener_key.lock();
680        if key_state.is_none() {
681            *key_state = Some(
682                current_task.kernel().device_registry.register_listener(locked, self.inner.clone()),
683            );
684        }
685    }
686}
687
688impl SocketOps for UEventNetlinkSocket {
689    fn connect(
690        &self,
691        locked: &mut Locked<FileOpsCore>,
692        _socket: &SocketHandle,
693        current_task: &CurrentTask,
694        peer: SocketPeer,
695    ) -> Result<(), Errno> {
696        let mut state = self.lock();
697        state.connect(current_task, peer)?;
698        self.register_listener(locked, current_task, state);
699        Ok(())
700    }
701
702    fn listen(
703        &self,
704        _locked: &mut Locked<FileOpsCore>,
705        _socket: &Socket,
706        _backlog: i32,
707        _credentials: ucred,
708    ) -> Result<(), Errno> {
709        error!(EOPNOTSUPP)
710    }
711
712    fn accept(
713        &self,
714        _locked: &mut Locked<FileOpsCore>,
715        _socket: &Socket,
716        _current_task: &CurrentTask,
717    ) -> Result<SocketHandle, Errno> {
718        error!(EOPNOTSUPP)
719    }
720
721    fn bind(
722        &self,
723        locked: &mut Locked<FileOpsCore>,
724        _socket: &Socket,
725        current_task: &CurrentTask,
726        socket_address: SocketAddress,
727    ) -> Result<(), Errno> {
728        let mut state = self.lock();
729        state.bind(current_task, socket_address)?;
730        self.register_listener(locked, current_task, state);
731        Ok(())
732    }
733
734    fn read(
735        &self,
736        _locked: &mut Locked<FileOpsCore>,
737        _socket: &Socket,
738        _current_task: &CurrentTask,
739        data: &mut dyn OutputBuffer,
740        flags: SocketMessageFlags,
741    ) -> Result<MessageReadInfo, Errno> {
742        self.lock().read_datagram(data, flags)
743    }
744
745    fn write(
746        &self,
747        _locked: &mut Locked<FileOpsCore>,
748        _socket: &Socket,
749        _current_task: &CurrentTask,
750        _data: &mut dyn InputBuffer,
751        _dest_address: &mut Option<SocketAddress>,
752        _ancillary_data: &mut Vec<AncillaryData>,
753    ) -> Result<usize, Errno> {
754        error!(EOPNOTSUPP)
755    }
756
757    fn wait_async(
758        &self,
759        _locked: &mut Locked<FileOpsCore>,
760        _socket: &Socket,
761        _current_task: &CurrentTask,
762        waiter: &Waiter,
763        events: FdEvents,
764        handler: EventHandler,
765    ) -> WaitCanceler {
766        self.lock().wait_async(waiter, events, handler)
767    }
768
769    fn query_events(
770        &self,
771        _locked: &mut Locked<FileOpsCore>,
772        _socket: &Socket,
773        _current_task: &CurrentTask,
774    ) -> Result<FdEvents, Errno> {
775        Ok(self.lock().query_events() & FdEvents::POLLIN)
776    }
777
778    fn shutdown(
779        &self,
780        _locked: &mut Locked<FileOpsCore>,
781        _socket: &Socket,
782        _how: SocketShutdownFlags,
783    ) -> Result<(), Errno> {
784        track_stub!(TODO("https://fxbug.dev/322875507"), "UEventNetlinkSocket::shutdown");
785        Ok(())
786    }
787
788    fn close(
789        &self,
790        locked: &mut Locked<FileOpsCore>,
791        current_task: &CurrentTask,
792        _socket: &Socket,
793    ) {
794        let id = self.device_listener_key.lock().take();
795        if let Some(id) = id {
796            current_task.kernel().device_registry.unregister_listener(locked, &id);
797        }
798    }
799
800    fn getsockname(
801        &self,
802        _locked: &mut Locked<FileOpsCore>,
803        _socket: &Socket,
804    ) -> Result<SocketAddress, Errno> {
805        self.lock().getsockname()
806    }
807
808    fn getpeername(
809        &self,
810        _locked: &mut Locked<FileOpsCore>,
811        _socket: &Socket,
812    ) -> Result<SocketAddress, Errno> {
813        self.lock().getpeername()
814    }
815
816    fn getsockopt(
817        &self,
818        _locked: &mut Locked<FileOpsCore>,
819        _socket: &Socket,
820        _current_task: &CurrentTask,
821        level: u32,
822        optname: u32,
823        _optlen: u32,
824    ) -> Result<Vec<u8>, Errno> {
825        self.lock().getsockopt(level, optname)
826    }
827
828    fn setsockopt(
829        &self,
830        _locked: &mut Locked<FileOpsCore>,
831        _socket: &Socket,
832        current_task: &CurrentTask,
833        level: u32,
834        optname: u32,
835        optval: SockOptValue,
836    ) -> Result<(), Errno> {
837        self.lock().setsockopt(current_task, level, optname, optval)
838    }
839}
840
841impl DeviceListener for Arc<Mutex<NetlinkSocketInner>> {
842    fn on_device_event(&self, action: UEventAction, device: Device, context: UEventContext) {
843        let path = device.path_from_depth(0);
844        let message = format!(
845            "{action}@/{path}\0\
846                            ACTION={action}\0\
847                            SEQNUM={seqnum}\0\
848                            {other_props}",
849            seqnum = context.seqnum,
850            other_props = device.uevent_properties('\0'),
851        );
852        let ancillary_data = AncillaryData::Unix(UnixControlData::Credentials(Default::default()));
853        let mut ancillary_data = vec![ancillary_data];
854        // Ignore write errors
855        let _ = self.lock().write_to_queue(
856            &mut VecInputBuffer::new(message.as_bytes()),
857            Some(NetlinkAddress { pid: 0, groups: 1 }),
858            &mut ancillary_data,
859        );
860    }
861}
862
/// Type for sending messages from [`netlink::Netlink`] to an individual socket.
///
/// Cloning the sender clones the `Arc`, so every clone refers to the same socket state.
#[derive(Clone)]
pub struct NetlinkToClientSender<M> {
    /// The inner socket implementation, which holds a message queue.
    inner: Arc<Mutex<NetlinkSocketInner>>,

    /// `PhantomData<fn(M) -> M>` is used instead of `PhantomData<M>` in order
    /// to ensure that the type is invariant over `M` and that it implements
    /// `Sync` even if `M` is not `Sync`.
    _message_type: PhantomData<fn(M) -> M>,
}
874
875impl<M> NetlinkToClientSender<M> {
876    fn new(inner: Arc<Mutex<NetlinkSocketInner>>) -> Self {
877        NetlinkToClientSender { _message_type: Default::default(), inner }
878    }
879}
880
881impl<M: Clone + NetlinkSerializable + Send> Sender<M> for NetlinkToClientSender<M> {
882    fn send(&mut self, message: NetlinkMessage<M>, group: Option<ModernGroup>) {
883        // Serialize the message
884        let mut buf = vec![0; message.buffer_len()];
885        message.emit(&mut buf);
886        let mut buf: VecInputBuffer = buf.into();
887        // Write the message into the inner socket buffer.
888        let NetlinkToClientSender { _message_type: _, inner } = self;
889        let mut guard = inner.lock();
890
891        // To avoid dropping messages when the receive buffer is
892        // full, grow the buffer on behalf of the client.
893        // This is a stop gap measure to avoid dropping messages
894        // when netlink produces a large response to a
895        // NLM_F_DUMP request.
896        //
897        // TODO(https://fxbug.dev/459883760): The memory
898        // implications of this may be problematic. It should be
899        // replaced with a proper mechanism to handle a backlog
900        // of NLM_F_DUMP responses.
901        let available = guard.receive_buffer.available_capacity();
902        let required = buf.available();
903        if available < required {
904            let delta = required - available;
905            let current_capacity = guard.receive_buffer.capacity();
906            let new_capacity = (current_capacity + delta).min(SOCKET_MAX_SIZE);
907            match guard.receive_buffer.set_capacity(new_capacity) {
908                Ok(()) => {}
909                Err(e) => {
910                    log_error!(
911                        tag = NETLINK_LOG_TAG;
912                        "Failed to increase receive buffer size: {:?}",
913                        e
914                    );
915                }
916            }
917        }
918
919        let _bytes_written: usize = guard
920            .write_to_queue(
921                &mut buf,
922                Some(NetlinkAddress {
923                    // All messages come from the "kernel" which has PID of 0.
924                    pid: 0,
925                    // If this is a multicast message, set the group the multicast
926                    // message is from.
927                    groups: group
928                        .map(SingleLegacyGroup::try_from)
929                        .and_then(Result::<_, NoMappingFromModernToLegacyGroupError>::ok)
930                        .map_or(0, |g| g.inner()),
931                }),
932                &mut Vec::new(),
933            )
934            .unwrap_or_else(|e| {
935                log_error!(
936                    tag = NETLINK_LOG_TAG;
937                    "Failed to write message into buffer for socket. Errno: {:?}",
938                    e
939                );
940                0
941            });
942    }
943}
944
945#[derive(Clone)]
946pub struct NetlinkAccessControl<'a> {
947    current_task: &'a CurrentTask,
948}
949
950impl<'a> NetlinkAccessControl<'a> {
951    pub fn new(current_task: &'a CurrentTask) -> Self {
952        Self { current_task }
953    }
954}
955
956impl<'a> AccessControl<Arc<Credentials>> for NetlinkAccessControl<'a> {
957    fn grant_assess(
958        &self,
959        creds: &Arc<Credentials>,
960        permission: Permission,
961    ) -> Result<(), netlink::Errno> {
962        let need_cap_net_admin = match permission {
963            Permission::NetlinkRouteRead => false,
964            Permission::NetlinkRouteWrite => true,
965            Permission::NetlinkSockDiagRead => false,
966            Permission::NetlinkSockDiagDestroy => true,
967        };
968        if !need_cap_net_admin {
969            return Ok(());
970        }
971
972        self.current_task.override_creds(creds.clone(), || {
973            security::check_task_capable(self.current_task, CAP_NET_ADMIN).map_err(|error| {
974                netlink::Errno::new(error.code.error_code() as i32)
975                    .expect("Errno::error_code() is expected to be in range [1..max_i32]")
976            })
977        })
978    }
979}
/// Marker type carrying this file's [`NetlinkContext`] implementation.
pub struct NetlinkContextImpl;
981
982impl NetlinkContext for NetlinkContextImpl {
983    type Creds = Arc<Credentials>;
984    type Sender<M: Clone + NetlinkSerializable + Send> = NetlinkToClientSender<M>;
985    type Receiver<
986        M: Send + MessageWithPermission + NetlinkDeserializable<Error: Into<DecodeError>>,
987    > = UnboundedReceiver<NetlinkMessageWithCreds<UnparsedNetlinkMessage<Vec<u8>, M>, Self::Creds>>;
988    type AccessControl<'a> = NetlinkAccessControl<'a>;
989}
990
991fn new_route_socket(kernel: &Arc<Kernel>) -> Result<NetlinkSocket<NetlinkRouteClient>, Errno> {
992    let inner = Arc::new(Mutex::new(NetlinkSocketInner::new(NetlinkFamily::Route)));
993    let (message_sender, message_receiver) = mpsc::unbounded();
994    let client = match kernel
995        .network_netlink()
996        .new_route_client(NetlinkToClientSender::new(inner.clone()), message_receiver)
997    {
998        Ok(client) => client,
999        Err(NewClientError::Disconnected) => {
1000            log_error!(
1001                tag = NETLINK_LOG_TAG;
1002                "Netlink async worker is unexpectedly disconnected"
1003            );
1004            return error!(EPIPE);
1005        }
1006    };
1007    Ok(NetlinkSocket { inner, client, message_sender })
1008}
1009
1010fn new_sock_diag_socket(
1011    kernel: &Arc<Kernel>,
1012) -> Result<NetlinkSocket<NetlinkSockDiagClient>, Errno> {
1013    let inner = Arc::new(Mutex::new(NetlinkSocketInner::new(NetlinkFamily::SockDiag)));
1014    let (message_sender, message_receiver) = mpsc::unbounded();
1015    let client = match kernel
1016        .network_netlink()
1017        .new_sock_diag_client(NetlinkToClientSender::new(inner.clone()), message_receiver)
1018    {
1019        Ok(client) => client,
1020        Err(NewClientError::Disconnected) => {
1021            log_error!(
1022                tag = NETLINK_LOG_TAG;
1023                "Netlink async worker is unexpectedly disconnected"
1024            );
1025            return error!(EPIPE);
1026        }
1027    };
1028    Ok(NetlinkSocket { inner, client, message_sender })
1029}
1030
1031/// An abstraction over common networking-specific netlink sockets.
1032struct NetlinkSocket<C: NetlinkClient> {
1033    /// The inner Netlink socket implementation
1034    inner: Arc<Mutex<NetlinkSocketInner>>,
1035    /// The implementation of a client (socket connection) to a netlink protocol
1036    /// family.
1037    client: C,
1038    /// The sender of messages from this socket to Netlink.
1039    // TODO(https://issuetracker.google.com/285880057): Bound the capacity of
1040    // the "send buffer".
1041    message_sender: UnboundedSender<
1042        NetlinkMessageWithCreds<UnparsedNetlinkMessage<Vec<u8>, C::Request>, Arc<Credentials>>,
1043    >,
1044}
1045
1046/// A type that provides Netlink message deserialization options.
1047trait DeserializeOptionsProvider {
1048    /// The type of the message to deserialize.
1049    type Message: NetlinkDeserializable;
1050    /// The options to use when deserializing a `Message`.
1051    fn options(&self) -> <Self::Message as NetlinkDeserializable>::DeserializeOptions;
1052}
1053
1054impl DeserializeOptionsProvider for NetlinkSocket<NetlinkRouteClient> {
1055    type Message = RouteNetlinkMessage;
1056    fn options(&self) -> RouteNetlinkMessageParseMode {
1057        let strict = self.inner.lock().strict_chk;
1058        if strict {
1059            RouteNetlinkMessageParseMode::Strict
1060        } else {
1061            RouteNetlinkMessageParseMode::Relaxed
1062        }
1063    }
1064}
1065
1066impl DeserializeOptionsProvider for NetlinkSocket<NetlinkSockDiagClient> {
1067    type Message = SockDiagRequest;
1068    fn options(&self) -> EmptyDeserializeSockDiagOptions {
1069        EmptyDeserializeSockDiagOptions
1070    }
1071}
1072
1073impl<C: NetlinkClient + 'static> SocketOps for NetlinkSocket<C>
1074where
1075    Self: DeserializeOptionsProvider<Message = C::Request>,
1076{
1077    fn connect(
1078        &self,
1079        _locked: &mut Locked<FileOpsCore>,
1080        _socket: &SocketHandle,
1081        current_task: &CurrentTask,
1082        peer: SocketPeer,
1083    ) -> Result<(), Errno> {
1084        let NetlinkSocket { inner, client: _, message_sender: _ } = self;
1085        inner.lock().connect(current_task, peer)
1086    }
1087
1088    fn listen(
1089        &self,
1090        _locked: &mut Locked<FileOpsCore>,
1091        _socket: &Socket,
1092        _backlog: i32,
1093        _credentials: ucred,
1094    ) -> Result<(), Errno> {
1095        error!(EOPNOTSUPP)
1096    }
1097
1098    fn accept(
1099        &self,
1100        _locked: &mut Locked<FileOpsCore>,
1101        _socket: &Socket,
1102        _current_task: &CurrentTask,
1103    ) -> Result<SocketHandle, Errno> {
1104        error!(EOPNOTSUPP)
1105    }
1106
1107    fn bind(
1108        &self,
1109        _locked: &mut Locked<FileOpsCore>,
1110        _socket: &Socket,
1111        current_task: &CurrentTask,
1112        socket_address: SocketAddress,
1113    ) -> Result<(), Errno> {
1114        let NetlinkSocket { inner, client, message_sender: _ } = self;
1115
1116        let multicast_groups = match &socket_address {
1117            SocketAddress::Netlink(NetlinkAddress { pid: _, groups }) => *groups,
1118            _ => return error!(EINVAL),
1119        };
1120        let pid = {
1121            let mut inner = inner.lock();
1122            inner.bind(current_task, socket_address)?;
1123            inner
1124                .address
1125                .as_ref()
1126                .and_then(|NetlinkAddress { pid, groups: _ }| NonZeroU32::new(*pid))
1127        };
1128        if let Some(pid) = pid {
1129            client.set_pid(pid);
1130        }
1131        // This "blocks" in order to synchronize with the internal
1132        // state of the netlink worker, but we're not blocking on
1133        // the completion of any i/o or any expensive computation,
1134        // so there's no need to support interrupts here.
1135        client
1136            .set_legacy_memberships(LegacyGroups(multicast_groups))
1137            .map_err(|InvalidLegacyGroupsError {}| errno!(EPERM))?
1138            .wait_until_complete();
1139        Ok(())
1140    }
1141
1142    fn read(
1143        &self,
1144        _locked: &mut Locked<FileOpsCore>,
1145        _socket: &Socket,
1146        _current_task: &CurrentTask,
1147        data: &mut dyn OutputBuffer,
1148        flags: SocketMessageFlags,
1149    ) -> Result<MessageReadInfo, Errno> {
1150        let NetlinkSocket { inner, client: _, message_sender: _ } = self;
1151        inner.lock().read_datagram(data, flags)
1152    }
1153
1154    fn write(
1155        &self,
1156        _locked: &mut Locked<FileOpsCore>,
1157        socket: &Socket,
1158        current_task: &CurrentTask,
1159        data: &mut dyn InputBuffer,
1160        _dest_address: &mut Option<SocketAddress>,
1161        _ancillary_data: &mut Vec<AncillaryData>,
1162    ) -> Result<usize, Errno> {
1163        let NetlinkSocket { inner: _, client: _, message_sender } = self;
1164
1165        let bytes = data.peek_all()?;
1166        let bytes_len = bytes.len();
1167
1168        // Parse only the netlink header to send it through security check.
1169        match NetlinkBuffer::new(&bytes) {
1170            Ok(buffer) => {
1171                security::check_netlink_send_access(current_task, socket, buffer.message_type())?;
1172            }
1173            Err(e) => {
1174                // If we can't even decode the header of the netlink message,
1175                // then return early here as a stronger statement that we're not
1176                // going to accidentally operate on it and violate the security
1177                // check. The netlink crate would end up dropping this with no
1178                // response as well.
1179                log_warn!(tag = NETLINK_LOG_TAG;
1180                    "Failed to parse netlink header {e:?}"
1181                );
1182                data.drain();
1183                return Ok(bytes_len);
1184            }
1185        }
1186
1187        let msg = NetlinkMessageWithCreds::new(
1188            UnparsedNetlinkMessage::new(bytes, self.options()),
1189            current_task.current_creds().clone(),
1190        );
1191        message_sender.unbounded_send(msg).map_err(|e| {
1192            log_warn!(
1193                tag = NETLINK_LOG_TAG;
1194                "Netlink receiver unexpectedly disconnected for socket: {:?}",
1195                e
1196            );
1197            errno!(EPIPE)
1198        })?;
1199        data.drain();
1200        Ok(bytes_len)
1201    }
1202
1203    fn wait_async(
1204        &self,
1205        _locked: &mut Locked<FileOpsCore>,
1206        _socket: &Socket,
1207        _current_task: &CurrentTask,
1208        waiter: &Waiter,
1209        events: FdEvents,
1210        handler: EventHandler,
1211    ) -> WaitCanceler {
1212        let NetlinkSocket { inner, client: _, message_sender: _ } = self;
1213        inner.lock().wait_async(waiter, events, handler)
1214    }
1215
1216    fn query_events(
1217        &self,
1218        _locked: &mut Locked<FileOpsCore>,
1219        _socket: &Socket,
1220        _current_task: &CurrentTask,
1221    ) -> Result<FdEvents, Errno> {
1222        let NetlinkSocket { inner, client: _, message_sender: _ } = self;
1223        Ok(inner.lock().query_events() & FdEvents::POLLIN)
1224    }
1225
1226    fn shutdown(
1227        &self,
1228        _locked: &mut Locked<FileOpsCore>,
1229        _socket: &Socket,
1230        _how: SocketShutdownFlags,
1231    ) -> Result<(), Errno> {
1232        error!(EOPNOTSUPP)
1233    }
1234
1235    fn close(
1236        &self,
1237        _locked: &mut Locked<FileOpsCore>,
1238        _current_task: &CurrentTask,
1239        _socket: &Socket,
1240    ) {
1241        // Close the underlying channel to the Netlink worker.
1242        self.message_sender.close_channel();
1243    }
1244
1245    fn getsockname(
1246        &self,
1247        _locked: &mut Locked<FileOpsCore>,
1248        _socket: &Socket,
1249    ) -> Result<SocketAddress, Errno> {
1250        let NetlinkSocket { inner, client: _, message_sender: _ } = self;
1251        inner.lock().getsockname()
1252    }
1253
1254    fn getpeername(
1255        &self,
1256        _locked: &mut Locked<FileOpsCore>,
1257        _socket: &Socket,
1258    ) -> Result<SocketAddress, Errno> {
1259        self.inner.lock().getpeername()
1260    }
1261
1262    fn getsockopt(
1263        &self,
1264        _locked: &mut Locked<FileOpsCore>,
1265        _socket: &Socket,
1266        _current_task: &CurrentTask,
1267        level: u32,
1268        optname: u32,
1269        _optlen: u32,
1270    ) -> Result<Vec<u8>, Errno> {
1271        self.inner.lock().getsockopt(level, optname)
1272    }
1273
1274    fn setsockopt(
1275        &self,
1276        _locked: &mut Locked<FileOpsCore>,
1277        _socket: &Socket,
1278        current_task: &CurrentTask,
1279        level: u32,
1280        optname: u32,
1281        optval: SockOptValue,
1282    ) -> Result<(), Errno> {
1283        match (level, optname) {
1284            (SOL_NETLINK, NETLINK_ADD_MEMBERSHIP) => {
1285                let NetlinkSocket { inner: _, client, message_sender: _ } = self;
1286                let group: u32 = optval.read(current_task)?;
1287                let async_work = client
1288                    .add_membership(ModernGroup(group))
1289                    .map_err(|InvalidModernGroupError| errno!(EINVAL))?;
1290                // This "blocks" in order to synchronize with the internal
1291                // state of the rtnetlink worker, but we're not blocking on
1292                // the completion of any i/o or any expensive computation,
1293                // so there's no need to support interrupts here.
1294                async_work.wait_until_complete();
1295                Ok(())
1296            }
1297            (SOL_NETLINK, NETLINK_DROP_MEMBERSHIP) => {
1298                let NetlinkSocket { inner: _, client, message_sender: _ } = self;
1299                let group: u32 = optval.read(current_task)?;
1300                client
1301                    .del_membership(ModernGroup(group))
1302                    .map_err(|InvalidModernGroupError| errno!(EINVAL))?;
1303                Ok(())
1304            }
1305            _ => self.inner.lock().setsockopt(current_task, level, optname, optval),
1306        }
1307    }
1308}
1309
1310/// Socket implementation for the NETLINK_GENERIC family of netlink sockets.
1311struct GenericNetlinkSocket {
1312    inner: Arc<Mutex<NetlinkSocketInner>>,
1313    client: GenericNetlinkClientHandle<NetlinkToClientSender<GenericMessage>>,
1314    message_sender: mpsc::UnboundedSender<NetlinkMessage<GenericMessage>>,
1315}
1316
1317impl GenericNetlinkSocket {
1318    pub fn new(kernel: &Kernel) -> Result<Self, Errno> {
1319        let inner = Arc::new(Mutex::new(NetlinkSocketInner::new(NetlinkFamily::Generic)));
1320        let (message_sender, message_receiver) = mpsc::unbounded();
1321        match kernel
1322            .generic_netlink()
1323            .new_generic_client(NetlinkToClientSender::new(inner.clone()), message_receiver)
1324        {
1325            Ok(client) => Ok(Self { inner, client, message_sender }),
1326            Err(e) => {
1327                log_warn!(
1328                    tag = NETLINK_LOG_TAG;
1329                    "Failed to connect to generic netlink server. Errno: {:?}",
1330                    e
1331                );
1332                error!(EPIPE)
1333            }
1334        }
1335    }
1336
1337    /// Locks and returns the inner state of the Socket.
1338    fn lock(&self) -> starnix_sync::MutexGuard<'_, NetlinkSocketInner> {
1339        self.inner.lock()
1340    }
1341}
1342
1343impl SocketOps for GenericNetlinkSocket {
1344    fn connect(
1345        &self,
1346        _locked: &mut Locked<FileOpsCore>,
1347        _socket: &SocketHandle,
1348        current_task: &CurrentTask,
1349        peer: SocketPeer,
1350    ) -> Result<(), Errno> {
1351        let mut state = self.lock();
1352        state.connect(current_task, peer)
1353    }
1354
1355    fn listen(
1356        &self,
1357        _locked: &mut Locked<FileOpsCore>,
1358        _socket: &Socket,
1359        _backlog: i32,
1360        _credentials: ucred,
1361    ) -> Result<(), Errno> {
1362        error!(EOPNOTSUPP)
1363    }
1364
1365    fn accept(
1366        &self,
1367        _locked: &mut Locked<FileOpsCore>,
1368        _socket: &Socket,
1369        _current_task: &CurrentTask,
1370    ) -> Result<SocketHandle, Errno> {
1371        error!(EOPNOTSUPP)
1372    }
1373
1374    fn bind(
1375        &self,
1376        _locked: &mut Locked<FileOpsCore>,
1377        _socket: &Socket,
1378        current_task: &CurrentTask,
1379        socket_address: SocketAddress,
1380    ) -> Result<(), Errno> {
1381        let mut state = self.lock();
1382        state.bind(current_task, socket_address)
1383    }
1384
1385    fn read(
1386        &self,
1387        _locked: &mut Locked<FileOpsCore>,
1388        _socket: &Socket,
1389        _current_task: &CurrentTask,
1390        data: &mut dyn OutputBuffer,
1391        flags: SocketMessageFlags,
1392    ) -> Result<MessageReadInfo, Errno> {
1393        self.lock().read_datagram(data, flags)
1394    }
1395
1396    fn write(
1397        &self,
1398        _locked: &mut Locked<FileOpsCore>,
1399        _socket: &Socket,
1400        _current_task: &CurrentTask,
1401        data: &mut dyn InputBuffer,
1402        _dest_address: &mut Option<SocketAddress>,
1403        _ancillary_data: &mut Vec<AncillaryData>,
1404    ) -> Result<usize, Errno> {
1405        let bytes = data.read_all()?;
1406        match NetlinkMessage::<GenericMessage>::deserialize(&bytes, EmptyDeserializeGenlOptions) {
1407            Err(e) => {
1408                log_warn!("Failed to process write; data could not be deserialized: {:?}", e);
1409                error!(EINVAL)
1410            }
1411            Ok(msg) => match self.message_sender.unbounded_send(msg) {
1412                Ok(()) => Ok(bytes.len()),
1413                Err(e) => {
1414                    log_warn!("Netlink receiver unexpectedly disconnected for socket: {:?}", e);
1415                    error!(EPIPE)
1416                }
1417            },
1418        }
1419    }
1420
1421    fn wait_async(
1422        &self,
1423        _locked: &mut Locked<FileOpsCore>,
1424        _socket: &Socket,
1425        _current_task: &CurrentTask,
1426        waiter: &Waiter,
1427        events: FdEvents,
1428        handler: EventHandler,
1429    ) -> WaitCanceler {
1430        self.lock().wait_async(waiter, events, handler)
1431    }
1432
1433    fn query_events(
1434        &self,
1435        _locked: &mut Locked<FileOpsCore>,
1436        _socket: &Socket,
1437        _current_task: &CurrentTask,
1438    ) -> Result<FdEvents, Errno> {
1439        Ok(self.lock().query_events() & FdEvents::POLLIN)
1440    }
1441
1442    fn shutdown(
1443        &self,
1444        _locked: &mut Locked<FileOpsCore>,
1445        _socket: &Socket,
1446        _how: SocketShutdownFlags,
1447    ) -> Result<(), Errno> {
1448        track_stub!(TODO("https://fxbug.dev/322875507"), "GenericNetlinkSocket::shutdown");
1449        Ok(())
1450    }
1451
1452    fn close(
1453        &self,
1454        _locked: &mut Locked<FileOpsCore>,
1455        _current_task: &CurrentTask,
1456        _socket: &Socket,
1457    ) {
1458    }
1459
1460    fn getsockname(
1461        &self,
1462        _locked: &mut Locked<FileOpsCore>,
1463        _socket: &Socket,
1464    ) -> Result<SocketAddress, Errno> {
1465        self.lock().getsockname()
1466    }
1467
1468    fn getpeername(
1469        &self,
1470        _locked: &mut Locked<FileOpsCore>,
1471        _socket: &Socket,
1472    ) -> Result<SocketAddress, Errno> {
1473        self.lock().getpeername()
1474    }
1475
1476    fn getsockopt(
1477        &self,
1478        _locked: &mut Locked<FileOpsCore>,
1479        _socket: &Socket,
1480        _current_task: &CurrentTask,
1481        level: u32,
1482        optname: u32,
1483        _optlen: u32,
1484    ) -> Result<Vec<u8>, Errno> {
1485        self.lock().getsockopt(level, optname)
1486    }
1487
1488    fn setsockopt(
1489        &self,
1490        _locked: &mut Locked<FileOpsCore>,
1491        _socket: &Socket,
1492        current_task: &CurrentTask,
1493        level: u32,
1494        optname: u32,
1495        optval: SockOptValue,
1496    ) -> Result<(), Errno> {
1497        match (level, optname) {
1498            (SOL_NETLINK, NETLINK_ADD_MEMBERSHIP) => {
1499                let group_id: u32 = optval.read(current_task)?;
1500                self.client.add_membership(ModernGroup(group_id))
1501            }
1502            _ => self.lock().setsockopt(current_task, level, optname, optval),
1503        }
1504    }
1505}
1506
1507/// Audit client that can be attached to the `AuditLogger`.
1508pub struct AuditNetlinkClient {
1509    /// Reference to the `AuditLogger`.
1510    audit_logger: Arc<AuditLogger>,
1511    /// The waiters queue present in `AuditNetlinkSocket`.
1512    waiters: WaitQueue,
1513    /// Optional response from the `AuditLogger`.
1514    audit_response: Mutex<Option<NetlinkMessage<GenericMessage>>>,
1515}
1516
1517impl AuditNetlinkClient {
1518    fn new(audit_logger: Arc<AuditLogger>) -> Self {
1519        Self { audit_logger, waiters: Default::default(), audit_response: Mutex::new(None) }
1520    }
1521
1522    pub fn notify(&self) {
1523        self.waiters.notify_fd_events(FdEvents::POLLIN);
1524    }
1525
1526    /// Function to check the capabilities of the current task against CAP_AUDIT_*
1527    fn check_audit_access(
1528        &self,
1529        current_task: &CurrentTask,
1530        request_type: &AuditRequest,
1531    ) -> Result<(), Errno> {
1532        match request_type {
1533            AuditRequest::AuditGet | AuditRequest::AuditSet => {
1534                security::check_task_capable(current_task, CAP_AUDIT_CONTROL)
1535            }
1536            AuditRequest::AuditUser => security::check_task_capable(current_task, CAP_AUDIT_WRITE),
1537        }
1538    }
1539
1540    /// Function to process request coming from userspace, it returns the response after processing
1541    fn process_request(
1542        self: &Arc<Self>,
1543        current_task: &CurrentTask,
1544        nl_message: NetlinkMessage<GenericMessage>,
1545    ) -> Result<NetlinkMessage<GenericMessage>, Errno> {
1546        let (nl_header, nl_payload) = nl_message.into_parts();
1547        let audit_request_type = AuditRequest::try_from(nl_header.message_type as u32)?;
1548        self.check_audit_access(current_task, &audit_request_type)?;
1549
1550        // If there is no GenericMessage, return an ErrorMessage.
1551        let NetlinkPayload::InnerMessage(GenericMessage::Other { payload, .. }) = nl_payload else {
1552            return error!(EINVAL);
1553        };
1554        match audit_request_type {
1555            AuditRequest::AuditGet => self.process_get_status(nl_header.sequence_number),
1556            AuditRequest::AuditSet => self.process_set_status(current_task, nl_header, payload),
1557            AuditRequest::AuditUser => self.process_user_audit(nl_header, payload),
1558        }
1559    }
1560
1561    fn get_nl_response(&self, flags: SocketMessageFlags) -> Option<Vec<u8>> {
1562        if flags.contains(SocketMessageFlags::PEEK) {
1563            if let Some(message) = self.audit_response.lock().as_ref() {
1564                return Some(AuditNetlinkClient::serialize_nlmsg(message.clone()));
1565            }
1566        } else if let Some(message) = self.audit_response.lock().take() {
1567            return Some(AuditNetlinkClient::serialize_nlmsg(message));
1568        }
1569        None
1570    }
1571
1572    /// Function to read an audit message from `AuditLogger`.
1573    fn read_audit_log(self: &Arc<Self>) -> Option<Vec<u8>> {
1574        if let Some(AuditMessage { audit_type, message }) = self.audit_logger.read_audit_log(self) {
1575            return Some(AuditNetlinkClient::serialize_nlmsg(
1576                AuditNetlinkClient::build_audit_nlmsg(0, audit_type, message),
1577            ));
1578        }
1579        None
1580    }
1581
1582    /// Function to read the optional response if present or an audit message.
1583    fn read_nlmsg(self: &Arc<Self>, flags: SocketMessageFlags) -> Result<Vec<u8>, Errno> {
1584        // First check if there is a response and send it if present.
1585        // Send an audit message otherwise or return EAGAIN.
1586        self.get_nl_response(flags).or_else(|| self.read_audit_log()).ok_or_else(|| errno!(EAGAIN))
1587    }
1588
1589    fn process_get_status(
1590        &self,
1591        sequence_number: u32,
1592    ) -> Result<NetlinkMessage<GenericMessage>, Errno> {
1593        Ok(AuditNetlinkClient::build_audit_nlmsg(
1594            sequence_number,
1595            AUDIT_GET as u16,
1596            self.audit_logger.get_status().as_bytes().to_vec(),
1597        ))
1598    }
1599
1600    fn process_set_status(
1601        self: &Arc<Self>,
1602        current_task: &CurrentTask,
1603        nl_hdr: NetlinkHeader,
1604        nl_payload: Vec<u8>,
1605    ) -> Result<NetlinkMessage<GenericMessage>, Errno> {
1606        let Some(status) = audit_status::read_from_bytes(nl_payload.as_bytes()).ok() else {
1607            return error!(EINVAL);
1608        };
1609        self.audit_logger.set_status(current_task, status, self)?;
1610        Ok(AuditNetlinkClient::build_audit_ack(Ok(()), nl_hdr))
1611    }
1612
1613    fn process_user_audit(
1614        &self,
1615        nl_hdr: NetlinkHeader,
1616        nl_payload: Vec<u8>,
1617    ) -> Result<NetlinkMessage<GenericMessage>, Errno> {
1618        let audit_msg = String::from_utf8_lossy(nl_payload.as_bytes());
1619        self.audit_logger.audit_log(nl_hdr.message_type, move || audit_msg);
1620        Ok(AuditNetlinkClient::build_audit_ack(Ok(()), nl_hdr))
1621    }
1622
1623    fn query_events(self: &Arc<Self>) -> FdEvents {
1624        if self.audit_response.lock().is_some() || self.audit_logger.get_backlog_count(self) != 0 {
1625            return FdEvents::POLLIN;
1626        }
1627        FdEvents::empty()
1628    }
1629
1630    fn detach(self: &Arc<Self>) {
1631        self.audit_logger.detach_client(self);
1632    }
1633
1634    fn build_audit_nlmsg(
1635        seq_number: u32,
1636        msg_type: u16,
1637        payload: Vec<u8>,
1638    ) -> NetlinkMessage<GenericMessage> {
1639        // The family in GenericMessage can be used for message type, not only for the Netlink Family,
1640        // because after finalizing the message, the message type is equal to family.
1641        let nl_payload =
1642            NetlinkPayload::InnerMessage(GenericMessage::Other { family: msg_type, payload });
1643        let mut nl_header = NetlinkHeader::default();
1644        nl_header.sequence_number = seq_number;
1645        let mut message = NetlinkMessage::new(nl_header, nl_payload);
1646        message.finalize();
1647        message
1648    }
1649
1650    fn build_audit_ack(
1651        error: Result<(), Errno>,
1652        req_header: NetlinkHeader,
1653    ) -> NetlinkMessage<GenericMessage> {
1654        let error = {
1655            assert_eq!(req_header.buffer_len(), NETLINK_HEADER_LEN);
1656            let mut buffer = vec![0; NETLINK_HEADER_LEN];
1657            req_header.emit(&mut buffer);
1658
1659            let code = match error {
1660                Ok(()) => None,
1661                Err(e) => Some(
1662                    // Audit netlink errors are negative.
1663                    NonZeroI32::new(-(e.code.error_code() as i32))
1664                        .expect("Errno's code must be non-zero"),
1665                ),
1666            };
1667
1668            let mut error = ErrorMessage::default();
1669            error.code = code;
1670            error.header = buffer;
1671            error
1672        };
1673
1674        let payload = NetlinkPayload::<GenericMessage>::Error(error);
1675        let mut resp_header = NetlinkHeader::default();
1676        resp_header.message_type = NLMSG_ERROR;
1677        resp_header.sequence_number = req_header.sequence_number;
1678        let mut message = NetlinkMessage::new(resp_header, payload);
1679        message.finalize();
1680        message
1681    }
1682
1683    fn serialize_nlmsg(message: NetlinkMessage<GenericMessage>) -> Vec<u8> {
1684        let mut buf = vec![0; message.buffer_len()];
1685        message.serialize(&mut buf);
1686        buf
1687    }
1688}
1689
1690/// Audit Netlink Socket structure.
1691pub struct AuditNetlinkSocket {
1692    /// Reference to the `AuditNetlinkClient` associated with self.
1693    audit_client: Arc<AuditNetlinkClient>,
1694}
1695
1696impl AuditNetlinkSocket {
1697    pub fn new(kernel: &Kernel) -> Result<Self, Errno> {
1698        if kernel.audit_logger().is_disabled() {
1699            return error!(EPROTONOSUPPORT);
1700        }
1701        Ok(Self { audit_client: Arc::new(AuditNetlinkClient::new(kernel.audit_logger())) })
1702    }
1703}
1704
1705impl SocketOps for AuditNetlinkSocket {
1706    fn read(
1707        &self,
1708        _locked: &mut Locked<FileOpsCore>,
1709        _socket: &Socket,
1710        _current_task: &CurrentTask,
1711        data: &mut dyn OutputBuffer,
1712        flags: SocketMessageFlags,
1713    ) -> Result<MessageReadInfo, Errno> {
1714        let buf = self.audit_client.read_nlmsg(flags)?;
1715
1716        let size = data.write_all(buf.as_bytes())?;
1717        Ok(MessageReadInfo {
1718            bytes_read: size,
1719            message_length: size,
1720            address: Some(SocketAddress::Netlink(NetlinkAddress::default())),
1721            ancillary_data: vec![],
1722        })
1723    }
1724
1725    fn write(
1726        &self,
1727        _locked: &mut Locked<FileOpsCore>,
1728        socket: &Socket,
1729        current_task: &CurrentTask,
1730        data: &mut dyn InputBuffer,
1731        _dest_address: &mut Option<SocketAddress>,
1732        _ancillary_data: &mut Vec<AncillaryData>,
1733    ) -> Result<usize, Errno> {
1734        match NetlinkMessage::<GenericMessage>::deserialize(
1735            &(data.peek_all()?),
1736            EmptyDeserializeGenlOptions,
1737        ) {
1738            Ok(nl_message) => {
1739                let header = nl_message.header;
1740                security::check_netlink_send_access(current_task, socket, header.message_type)?;
1741
1742                // Send request to the `AuditNetlinkClient`.
1743                let audit_ack = self
1744                    .audit_client
1745                    .process_request(current_task, nl_message)
1746                    .map_err(|e| AuditNetlinkClient::build_audit_ack(Err(e), header))
1747                    .unwrap_or_else(|nlerr| nlerr);
1748                *self.audit_client.audit_response.lock() = Some(audit_ack);
1749                data.drain();
1750                Ok(header.length as usize)
1751            }
1752            Err(e) => {
1753                log_warn!("Failed to process write; data could not be deserialized: {:?}", e);
1754                error!(EINVAL)
1755            }
1756        }
1757    }
1758
1759    fn wait_async(
1760        &self,
1761        _locked: &mut Locked<FileOpsCore>,
1762        _socket: &Socket,
1763        _current_task: &CurrentTask,
1764        waiter: &Waiter,
1765        events: FdEvents,
1766        handler: EventHandler,
1767    ) -> WaitCanceler {
1768        self.audit_client.waiters.wait_async_fd_events(waiter, events, handler)
1769    }
1770
1771    fn query_events(
1772        &self,
1773        _locked: &mut Locked<FileOpsCore>,
1774        _socket: &Socket,
1775        _current_task: &CurrentTask,
1776    ) -> Result<FdEvents, Errno> {
1777        Ok(self.audit_client.query_events() & FdEvents::POLLIN)
1778    }
1779
1780    fn close(
1781        &self,
1782        _locked: &mut Locked<FileOpsCore>,
1783        _current_task: &CurrentTask,
1784        _socket: &Socket,
1785    ) {
1786        // If the `AuditNetlinkClient` disconnects, detach it.
1787        self.audit_client.detach();
1788    }
1789
1790    fn shutdown(
1791        &self,
1792        _locked: &mut Locked<FileOpsCore>,
1793        _socket: &Socket,
1794        _how: SocketShutdownFlags,
1795    ) -> Result<(), Errno> {
1796        error!(EOPNOTSUPP)
1797    }
1798
1799    fn connect(
1800        &self,
1801        _locked: &mut Locked<FileOpsCore>,
1802        _socket: &SocketHandle,
1803        _current_task: &CurrentTask,
1804        _peer: SocketPeer,
1805    ) -> Result<(), Errno> {
1806        error!(EOPNOTSUPP)
1807    }
1808
1809    fn listen(
1810        &self,
1811        _locked: &mut Locked<FileOpsCore>,
1812        _socket: &Socket,
1813        _backlog: i32,
1814        _credentials: ucred,
1815    ) -> Result<(), Errno> {
1816        error!(EOPNOTSUPP)
1817    }
1818
1819    fn accept(
1820        &self,
1821        _locked: &mut Locked<FileOpsCore>,
1822        _socket: &Socket,
1823        _current_task: &CurrentTask,
1824    ) -> Result<SocketHandle, Errno> {
1825        error!(EOPNOTSUPP)
1826    }
1827
1828    fn bind(
1829        &self,
1830        _locked: &mut Locked<FileOpsCore>,
1831        _socket: &Socket,
1832        _current_task: &CurrentTask,
1833        _socket_address: SocketAddress,
1834    ) -> Result<(), Errno> {
1835        error!(EOPNOTSUPP)
1836    }
1837
1838    fn getsockname(
1839        &self,
1840        _locked: &mut Locked<FileOpsCore>,
1841        _socket: &Socket,
1842    ) -> Result<SocketAddress, Errno> {
1843        error!(EOPNOTSUPP)
1844    }
1845
1846    fn getpeername(
1847        &self,
1848        _locked: &mut Locked<FileOpsCore>,
1849        _socket: &Socket,
1850    ) -> Result<SocketAddress, Errno> {
1851        error!(EOPNOTSUPP)
1852    }
1853
1854    fn getsockopt(
1855        &self,
1856        _locked: &mut Locked<FileOpsCore>,
1857        _socket: &Socket,
1858        _current_task: &CurrentTask,
1859        _level: u32,
1860        _optname: u32,
1861        _optlen: u32,
1862    ) -> Result<Vec<u8>, Errno> {
1863        error!(EOPNOTSUPP)
1864    }
1865
1866    fn setsockopt(
1867        &self,
1868        _locked: &mut Locked<FileOpsCore>,
1869        _socket: &Socket,
1870        _current_task: &CurrentTask,
1871        _level: u32,
1872        _optname: u32,
1873        _optval: SockOptValue,
1874    ) -> Result<(), Errno> {
1875        error!(EOPNOTSUPP)
1876    }
1877}
1878
#[cfg(test)]
mod tests {
    use super::*;

    use netlink_packet_route::route::RouteMessage;
    use netlink_packet_route::{RouteNetlinkMessage, RouteNetlinkMessageParseMode};
    use test_case::test_case;

    // With room in the queue, the message is stored and the capacity is unchanged.
    #[test_case(true; "sufficient_capacity")]
    // With a zero-sized queue, sending still succeeds: the queue is expected to
    // grow to the size of the message.
    #[test_case(false; "insufficient_capacity")]
    fn test_netlink_to_client_sender(sufficient_capacity: bool) {
        const MODERN_GROUP: u32 = 5;

        let message = {
            let mut msg: NetlinkMessage<RouteNetlinkMessage> =
                RouteNetlinkMessage::NewRoute(RouteMessage::default()).into();
            msg.finalize();
            msg
        };

        let (initial_queue_size, final_queue_size) = match sufficient_capacity {
            true => (SOCKET_DEFAULT_SIZE, SOCKET_DEFAULT_SIZE),
            false => (0, message.buffer_len()),
        };

        // Start from a default Route socket and swap in a queue of the chosen size.
        let mut inner = NetlinkSocketInner::new(NetlinkFamily::Route);
        inner.receive_buffer = MessageQueue::new(initial_queue_size);
        let socket_inner = Arc::new(Mutex::new(inner));

        let mut sender =
            NetlinkToClientSender::<RouteNetlinkMessage>::new(Arc::clone(&socket_inner));
        sender.send(message.clone(), Some(ModernGroup(MODERN_GROUP)));

        let received = socket_inner.lock().read_message().expect("should read message");

        let expected_address =
            Some(SocketAddress::Netlink(NetlinkAddress { pid: 0, groups: 1 << MODERN_GROUP }));
        assert_eq!(received.address, expected_address);

        let actual_message = NetlinkMessage::<RouteNetlinkMessage>::deserialize(
            &received.data,
            RouteNetlinkMessageParseMode::Strict,
        )
        .expect("message should deserialize into RtnlMessage");
        assert_eq!(actual_message, message);

        let capacity = socket_inner.lock().receive_buffer.capacity();
        assert_eq!(capacity, final_queue_size);
    }

    // Reads a socket option and decodes it as a native-endian `u32`.
    fn getsockopt_u32(socket: &NetlinkSocketInner, level: u32, optname: u32) -> u32 {
        let raw = socket.getsockopt(level, optname).expect("getsockopt should succeed");
        let word = <[u8; 4]>::try_from(raw.as_slice()).expect("expected 4 bytes");
        u32::from_ne_bytes(word)
    }

    // Wraps a `u32` as a native-endian `SockOptValue::Value`.
    fn sock_opt_value(val: u32) -> SockOptValue {
        let encoded = Vec::from(val.to_ne_bytes());
        SockOptValue::Value(encoded)
    }

    #[::fuchsia::test]
    async fn test_set_get_snd_rcv_buf() {
        crate::testing::spawn_kernel_and_run_sync(|_locked, current_task| {
            let mut socket = NetlinkSocketInner::new(NetlinkFamily::Route);

            // A fresh socket reports the default size for both buffers.
            let expected_default = u32::try_from(SOCKET_DEFAULT_SIZE).unwrap();
            for optname in [SO_SNDBUF, SO_RCVBUF] {
                assert_eq!(getsockopt_u32(&socket, SOL_SOCKET, optname), expected_default);
            }

            // Request new sizes; the value read back is twice the requested one.
            const SNDBUF_SIZE: u32 = 12345;
            const RCVBUF_SIZE: u32 = 54321;
            let requests = [(SO_SNDBUF, SNDBUF_SIZE), (SO_RCVBUF, RCVBUF_SIZE)];
            for (optname, requested) in requests {
                socket
                    .setsockopt(current_task, SOL_SOCKET, optname, sock_opt_value(requested))
                    .expect("setsockopt should succeed");
            }
            for (optname, requested) in requests {
                assert_eq!(getsockopt_u32(&socket, SOL_SOCKET, optname), requested * 2);
            }
        })
        .await;
    }

    #[::fuchsia::test]
    async fn test_snd_rcv_buf_limits() {
        crate::testing::spawn_kernel_and_run_sync(|_locked, current_task| {
            let mut socket = NetlinkSocketInner::new(NetlinkFamily::Route);
            let expected_max = u32::try_from(SOCKET_MAX_SIZE).unwrap();
            let too_big = expected_max + 1;

            // SO_SNDBUF and SO_RCVBUF clamp an oversized request to the limit.
            for optname in [SO_SNDBUF, SO_RCVBUF] {
                socket
                    .setsockopt(current_task, SOL_SOCKET, optname, sock_opt_value(too_big))
                    .expect("setsockopt should succeed");
            }
            for optname in [SO_SNDBUF, SO_RCVBUF] {
                assert_eq!(getsockopt_u32(&socket, SOL_SOCKET, optname), expected_max);
            }

            // The *FORCE variants skip the clamp; the value read back is twice
            // the requested one.
            for optname in [SO_SNDBUFFORCE, SO_RCVBUFFORCE] {
                socket
                    .setsockopt(current_task, SOL_SOCKET, optname, sock_opt_value(too_big))
                    .expect("setsockopt should succeed");
            }
            for optname in [SO_SNDBUF, SO_RCVBUF] {
                assert_eq!(getsockopt_u32(&socket, SOL_SOCKET, optname), too_big * 2);
            }
        })
        .await;
    }
}