netstack3_ip/
raw.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Facilities backing raw IP sockets.
6
7use alloc::collections::btree_map::Entry;
8use alloc::collections::{BTreeMap, HashMap};
9use core::fmt::{self, Debug, Display};
10use core::num::NonZeroU8;
11use derivative::Derivative;
12use log::debug;
13use net_types::ip::{GenericOverIp, Ip, IpVersionMarker, Mtu};
14use net_types::{SpecifiedAddr, ZonedAddr};
15use netstack3_base::socket::{DualStackIpExt, DualStackRemoteIp, SocketZonedAddrExt as _};
16use netstack3_base::sync::{PrimaryRc, StrongRc, WeakRc};
17use netstack3_base::{
18    AnyDevice, ContextPair, DeviceIdContext, Inspector, InspectorDeviceExt, InspectorExt,
19    IpDeviceAddr, IpExt, Mark, MarkDomain, Marks, ReferenceNotifiers, ReferenceNotifiersExt as _,
20    RemoveResourceResultWithContext, ResourceCounterContext, StrongDeviceIdentifier,
21    TxMetadataBindingsTypes, WeakDeviceIdentifier, ZonedAddressError,
22};
23use netstack3_filter::RawIpBody;
24use packet::{BufferMut, SliceBufViewMut};
25use packet_formats::icmp;
26use packet_formats::ip::{DscpAndEcn, IpPacket};
27use zerocopy::SplitByteSlice;
28
29use crate::internal::raw::counters::RawIpSocketCounters;
30use crate::internal::raw::filter::RawIpSocketIcmpFilter;
31use crate::internal::raw::protocol::RawIpSocketProtocol;
32use crate::internal::raw::state::{RawIpSocketLockedState, RawIpSocketState};
33use crate::internal::socket::{SendOneShotIpPacketError, SocketHopLimits};
34use crate::socket::{
35    IpSockCreateAndSendError, IpSocketHandler, RouteResolutionOptions, SendOptions,
36};
37use crate::DEFAULT_HOP_LIMITS;
38
39mod checksum;
40pub(crate) mod counters;
41pub(crate) mod filter;
42pub(crate) mod protocol;
43pub(crate) mod state;
44
45/// Types provided by bindings used in the raw IP socket implementation.
46pub trait RawIpSocketsBindingsTypes: TxMetadataBindingsTypes {
47    /// The bindings state (opaque to core) associated with a socket.
48    type RawIpSocketState<I: Ip>: Send + Sync + Debug;
49}
50
51/// Functionality provided by bindings used in the raw IP socket implementation.
52pub trait RawIpSocketsBindingsContext<I: IpExt, D: StrongDeviceIdentifier>:
53    RawIpSocketsBindingsTypes + Sized
54{
55    /// Called for each received IP packet that matches the provided socket.
56    fn receive_packet<B: SplitByteSlice>(
57        &self,
58        socket: &RawIpSocketId<I, D::Weak, Self>,
59        packet: &I::Packet<B>,
60        device: &D,
61    );
62}
63
64/// The raw IP socket API.
65pub struct RawIpSocketApi<I: Ip, C> {
66    ctx: C,
67    _ip_mark: IpVersionMarker<I>,
68}
69
70impl<I: Ip, C> RawIpSocketApi<I, C> {
71    /// Constructs a new RAW IP socket API.
72    pub fn new(ctx: C) -> Self {
73        Self { ctx, _ip_mark: IpVersionMarker::new() }
74    }
75}
76
77impl<I: IpExt + DualStackIpExt, C> RawIpSocketApi<I, C>
78where
79    C: ContextPair,
80    C::BindingsContext: RawIpSocketsBindingsTypes + ReferenceNotifiers + 'static,
81    C::CoreContext: RawIpSocketMapContext<I, C::BindingsContext>
82        + RawIpSocketStateContext<I, C::BindingsContext>
83        + ResourceCounterContext<RawIpApiSocketId<I, C>, RawIpSocketCounters<I>>,
84{
85    fn core_ctx(&mut self) -> &mut C::CoreContext {
86        let Self { ctx, _ip_mark } = self;
87        ctx.core_ctx()
88    }
89
90    fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
91        let Self { ctx, _ip_mark } = self;
92        ctx.contexts()
93    }
94
95    /// Creates a raw IP socket for the given protocol.
96    pub fn create(
97        &mut self,
98        protocol: RawIpSocketProtocol<I>,
99        external_state: <C::BindingsContext as RawIpSocketsBindingsTypes>::RawIpSocketState<I>,
100    ) -> RawIpApiSocketId<I, C> {
101        let socket =
102            PrimaryRawIpSocketId(PrimaryRc::new(RawIpSocketState::new(protocol, external_state)));
103        let strong = self.core_ctx().with_socket_map_mut(|socket_map| socket_map.insert(socket));
104        debug!("created raw IP socket {strong:?}, on protocol {protocol:?}");
105
106        if protocol.requires_system_checksums() {
107            self.core_ctx().with_locked_state_mut(&strong, |state| state.system_checksums = true)
108        }
109
110        strong
111    }
112
113    /// Removes the raw IP socket from the system, returning its external state.
114    pub fn close(
115        &mut self,
116        id: RawIpApiSocketId<I, C>,
117    ) -> RemoveResourceResultWithContext<
118        <C::BindingsContext as RawIpSocketsBindingsTypes>::RawIpSocketState<I>,
119        C::BindingsContext,
120    > {
121        let primary = self.core_ctx().with_socket_map_mut(|socket_map| socket_map.remove(id));
122        debug!("removed raw IP socket {primary:?}");
123        let PrimaryRawIpSocketId(primary) = primary;
124
125        C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
126            primary,
127            |state: RawIpSocketState<I, _, C::BindingsContext>| state.into_external_state(),
128        )
129    }
130
131    /// Sends an IP packet on the raw IP socket to the provided destination.
132    ///
133    /// The provided `body` is not expected to include an IP header; a system
134    /// determined header will automatically be applied.
135    pub fn send_to<B: BufferMut>(
136        &mut self,
137        id: &RawIpApiSocketId<I, C>,
138        remote_ip: Option<
139            ZonedAddr<
140                SpecifiedAddr<I::Addr>,
141                <C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId,
142            >,
143        >,
144        mut body: B,
145    ) -> Result<(), RawIpSocketSendToError> {
146        match id.protocol() {
147            RawIpSocketProtocol::Raw => return Err(RawIpSocketSendToError::ProtocolRaw),
148            RawIpSocketProtocol::Proto(_) => {}
149        }
150        // TODO(https://fxbug.dev/339692009): Return an error if IP_HDRINCL is
151        // set.
152
153        // TODO(https://fxbug.dev/342579393): Use the socket's bound address.
154        let local_ip = None;
155
156        let remote_ip = match DualStackRemoteIp::<I, _>::new(remote_ip) {
157            DualStackRemoteIp::ThisStack(addr) => addr,
158            DualStackRemoteIp::OtherStack(_addr) => {
159                return Err(RawIpSocketSendToError::MappedRemoteIp)
160            }
161        };
162        let protocol = id.protocol().proto();
163
164        let (core_ctx, bindings_ctx) = self.contexts();
165        let result = core_ctx.with_locked_state_and_socket_handler(id, |state, core_ctx| {
166            let RawIpSocketLockedState {
167                bound_device,
168                icmp_filter: _,
169                hop_limits,
170                multicast_loop,
171                system_checksums,
172                marks,
173            } = state;
174            let (remote_ip, device) = remote_ip
175                .resolve_addr_with_device(bound_device.clone())
176                .map_err(RawIpSocketSendToError::Zone)?;
177            let send_options = RawIpSocketOptions {
178                hop_limits: &hop_limits,
179                multicast_loop: *multicast_loop,
180                marks: &marks,
181            };
182
183            let build_packet_fn =
184                |src_ip: IpDeviceAddr<I::Addr>| -> Result<RawIpBody<_, _>, RawIpSocketSendToError> {
185                    if *system_checksums {
186                        let buf = SliceBufViewMut::new(body.as_mut());
187                        if !checksum::populate_checksum::<I, _>(
188                            src_ip.addr(),
189                            remote_ip.addr(),
190                            protocol,
191                            buf,
192                        ) {
193                            return Err(RawIpSocketSendToError::InvalidBody);
194                        }
195                    }
196                    Ok(RawIpBody::new(protocol, src_ip.addr(), remote_ip.addr(), body))
197                };
198
199            // TODO(https://fxbug.dev/392111277): Enforce send buffer for raw ip
200            // sockets.
201            let tx_metadata = Default::default();
202
203            core_ctx
204                .send_oneshot_ip_packet_with_fallible_serializer(
205                    bindings_ctx,
206                    device.as_ref().map(|d| d.as_ref()),
207                    local_ip,
208                    remote_ip,
209                    protocol,
210                    &send_options,
211                    tx_metadata,
212                    build_packet_fn,
213                )
214                .map_err(|e| match e {
215                    SendOneShotIpPacketError::CreateAndSendError { err } => {
216                        RawIpSocketSendToError::Ip(err)
217                    }
218                    SendOneShotIpPacketError::SerializeError(err) => err,
219                })
220        });
221        match &result {
222            Ok(()) => core_ctx
223                .increment_both(&id, |counters: &RawIpSocketCounters<I>| &counters.tx_packets),
224            Err(RawIpSocketSendToError::InvalidBody) => core_ctx
225                .increment_both(&id, |counters: &RawIpSocketCounters<I>| {
226                    &counters.tx_checksum_errors
227                }),
228            Err(_) => {}
229        }
230        result
231    }
232
233    // TODO(https://fxbug.dev/342577389): Add a `send` function that does not
234    // require a remote_ip to support sending on connected sockets.
235    // TODO(https://fxbug.dev/339692009): Add a `send` function that does not
236    // require a remote_ip to support sending when the remote_ip is provided via
237    // IP_HDRINCL.
238
239    /// Sets the socket's bound device, returning the original value.
240    pub fn set_device(
241        &mut self,
242        id: &RawIpApiSocketId<I, C>,
243        device: Option<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
244    ) -> Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
245        let device = device.map(|strong| strong.downgrade());
246        // TODO(https://fxbug.dev/342579393): Verify the device is compatible
247        // with the socket's bound address.
248        // TODO(https://fxbug.dev/342577389): Verify the device is compatible
249        // with the socket's peer address.
250        self.core_ctx()
251            .with_locked_state_mut(id, |state| core::mem::replace(&mut state.bound_device, device))
252    }
253
254    /// Gets the socket's bound device,
255    pub fn get_device(
256        &mut self,
257        id: &RawIpApiSocketId<I, C>,
258    ) -> Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
259        self.core_ctx().with_locked_state(id, |state| state.bound_device.clone())
260    }
261
262    /// Sets the socket's ICMP filter, returning the original value.
263    ///
264    /// Note, if the socket's protocol is not compatible (e.g. ICMPv4 for an
265    /// IPv4 socket, or ICMPv6 for an IPv6 socket), an error is returned.
266    pub fn set_icmp_filter(
267        &mut self,
268        id: &RawIpApiSocketId<I, C>,
269        filter: Option<RawIpSocketIcmpFilter<I>>,
270    ) -> Result<Option<RawIpSocketIcmpFilter<I>>, RawIpSocketIcmpFilterError> {
271        debug!("setting ICMP Filter on {id:?}: {filter:?}");
272        if !id.protocol().is_icmp() {
273            return Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp);
274        }
275        Ok(self
276            .core_ctx()
277            .with_locked_state_mut(id, |state| core::mem::replace(&mut state.icmp_filter, filter)))
278    }
279
280    /// Gets the socket's ICMP
281    ///
282    /// Note, if the socket's protocol is not compatible (e.g. ICMPv4 for an
283    /// IPv4 socket, or ICMPv6 for an IPv6 socket), an error is returned.
284    pub fn get_icmp_filter(
285        &mut self,
286        id: &RawIpApiSocketId<I, C>,
287    ) -> Result<Option<RawIpSocketIcmpFilter<I>>, RawIpSocketIcmpFilterError> {
288        if !id.protocol().is_icmp() {
289            return Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp);
290        }
291        Ok(self.core_ctx().with_locked_state(id, |state| state.icmp_filter.clone()))
292    }
293
294    /// Sets the socket's unicast hop limit, returning the original value.
295    ///
296    /// If `None` is provided, the hop limit will be restored to the system
297    /// default.
298    pub fn set_unicast_hop_limit(
299        &mut self,
300        id: &RawIpApiSocketId<I, C>,
301        new_limit: Option<NonZeroU8>,
302    ) -> Option<NonZeroU8> {
303        self.core_ctx().with_locked_state_mut(id, |state| {
304            core::mem::replace(&mut state.hop_limits.unicast, new_limit)
305        })
306    }
307
308    /// Gets the socket's unicast hop limit, or the system default, if unset.
309    pub fn get_unicast_hop_limit(&mut self, id: &RawIpApiSocketId<I, C>) -> NonZeroU8 {
310        self.core_ctx().with_locked_state(id, |state| {
311            state.hop_limits.get_limits_with_defaults(&DEFAULT_HOP_LIMITS).unicast
312        })
313    }
314
315    /// Sets the socket's multicast hop limit, returning the original value.
316    ///
317    /// If `None` is provided, the hop limit will be restored to the system
318    /// default.
319    pub fn set_multicast_hop_limit(
320        &mut self,
321        id: &RawIpApiSocketId<I, C>,
322        new_limit: Option<NonZeroU8>,
323    ) -> Option<NonZeroU8> {
324        self.core_ctx().with_locked_state_mut(id, |state| {
325            core::mem::replace(&mut state.hop_limits.multicast, new_limit)
326        })
327    }
328
329    /// Gets the socket's multicast hop limit, or the system default, if unset.
330    pub fn get_multicast_hop_limit(&mut self, id: &RawIpApiSocketId<I, C>) -> NonZeroU8 {
331        self.core_ctx().with_locked_state(id, |state| {
332            state.hop_limits.get_limits_with_defaults(&DEFAULT_HOP_LIMITS).multicast
333        })
334    }
335
336    /// Sets `multicast_loop` on the socket, returning the original value.
337    ///
338    /// When true, the socket will loop back all sent multicast traffic.
339    pub fn set_multicast_loop(&mut self, id: &RawIpApiSocketId<I, C>, value: bool) -> bool {
340        self.core_ctx()
341            .with_locked_state_mut(id, |state| core::mem::replace(&mut state.multicast_loop, value))
342    }
343
344    /// Gets the `multicast_loop` value on the socket.
345    pub fn get_multicast_loop(&mut self, id: &RawIpApiSocketId<I, C>) -> bool {
346        self.core_ctx().with_locked_state(id, |state| state.multicast_loop)
347    }
348
349    /// Sets the socket mark for the socket domain.
350    pub fn set_mark(&mut self, id: &RawIpApiSocketId<I, C>, domain: MarkDomain, mark: Mark) {
351        self.core_ctx().with_locked_state_mut(id, |state| {
352            *state.marks.get_mut(domain) = mark;
353        })
354    }
355
356    /// Gets the socket mark for the socket domain.
357    pub fn get_mark(&mut self, id: &RawIpApiSocketId<I, C>, domain: MarkDomain) -> Mark {
358        self.core_ctx().with_locked_state(id, |state| *state.marks.get(domain))
359    }
360
361    /// Provides inspect data for raw IP sockets.
362    pub fn inspect<N>(&mut self, inspector: &mut N)
363    where
364        N: Inspector
365            + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
366    {
367        self.core_ctx().with_socket_map_and_state_ctx(|socket_map, core_ctx| {
368            socket_map.iter_sockets().for_each(|socket| {
369                inspector.record_debug_child(socket, |node| {
370                    node.record_display("TransportProtocol", socket.protocol().proto());
371                    node.record_str("NetworkProtocol", I::NAME);
372                    // TODO(https://fxbug.dev/342579393): Support `bind`.
373                    node.record_local_socket_addr::<
374                        N,
375                        I::Addr,
376                        <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
377                        NoPortMarker,
378                    >(None);
379                    // TODO(https://fxbug.dev/342577389): Support `connect`.
380                    node.record_remote_socket_addr::<
381                        N,
382                        I::Addr,
383                        <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
384                        NoPortMarker,
385                    >(None);
386                    core_ctx.with_locked_state(socket, |state| {
387                        let RawIpSocketLockedState {
388                            bound_device,
389                            icmp_filter,
390                            hop_limits: _,
391                            multicast_loop: _,
392                            marks: _,
393                            system_checksums: _,
394                        } = state;
395                        if let Some(bound_device) = bound_device {
396                            N::record_device(node, "BoundDevice", bound_device);
397                        } else {
398                            node.record_str("BoundDevice", "None");
399                        }
400                        if let Some(icmp_filter) = icmp_filter {
401                            node.record_display("IcmpFilter", icmp_filter);
402                        } else {
403                            node.record_str("IcmpFilter", "None");
404                        }
405                    });
406                    node.record_child("Counters", |node| {
407                        node.delegate_inspectable(socket.state().counters())
408                    })
409                })
410            })
411        })
412    }
413}
414
415/// Errors that may occur when calling [`RawIpSocketApi::send_to`].
416#[derive(Debug)]
417pub enum RawIpSocketSendToError {
418    /// The socket's protocol is `RawIpSocketProtocol::Raw`, which disallows
419    /// `send_to` (the remote IP should be specified in the included header, not
420    /// as a separate address argument).
421    ProtocolRaw,
422    /// The provided remote_ip was an IPv4-mapped-IPv6 address. Dual stack
423    /// operations are not supported on raw IP sockets.
424    MappedRemoteIp,
425    /// The provided packet body was invalid, and could not be sent. Typically
426    /// originates when the stack is asked to inspect the packet body, e.g. to
427    /// compute and populate the checksum value.
428    InvalidBody,
429    /// There was an error when resolving the remote_ip's zone.
430    Zone(ZonedAddressError),
431    /// The IP layer failed to send the packet.
432    Ip(IpSockCreateAndSendError),
433}
434
/// Failure modes when reading or writing a raw IP socket's ICMP filter.
#[derive(Debug, PartialEq)]
pub enum RawIpSocketIcmpFilterError {
    /// ICMP filters are not supported for the socket's protocol.
    ProtocolNotIcmp,
}
441
442/// The owner of socket state.
443struct PrimaryRawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
444    PrimaryRc<RawIpSocketState<I, D, BT>>,
445);
446
447impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
448    for PrimaryRawIpSocketId<I, D, BT>
449{
450    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
451        let Self(rc) = self;
452        f.debug_tuple("RawIpSocketId").field(&PrimaryRc::debug_id(rc)).finish()
453    }
454}
455
456/// Reference to the state of a live socket.
457#[derive(Derivative, GenericOverIp)]
458#[derivative(Clone(bound = ""), Eq(bound = ""), Hash(bound = ""), PartialEq(bound = ""))]
459#[generic_over_ip(I, Ip)]
460pub struct RawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
461    StrongRc<RawIpSocketState<I, D, BT>>,
462);
463
464impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketId<I, D, BT> {
465    /// Return the bindings state associated with this socket.
466    pub fn external_state(&self) -> &BT::RawIpSocketState<I> {
467        let RawIpSocketId(strong_rc) = self;
468        strong_rc.external_state()
469    }
470    /// Return the protocol associated with this socket.
471    pub fn protocol(&self) -> &RawIpSocketProtocol<I> {
472        let RawIpSocketId(strong_rc) = self;
473        strong_rc.protocol()
474    }
475    /// Downgrades this ID to a weak reference.
476    pub fn downgrade(&self) -> WeakRawIpSocketId<I, D, BT> {
477        let Self(rc) = self;
478        WeakRawIpSocketId(StrongRc::downgrade(rc))
479    }
480    /// Gets the socket state.
481    pub fn state(&self) -> &RawIpSocketState<I, D, BT> {
482        let RawIpSocketId(strong_rc) = self;
483        &*strong_rc
484    }
485}
486
487impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
488    for RawIpSocketId<I, D, BT>
489{
490    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
491        let Self(rc) = self;
492        f.debug_tuple("RawIpSocketId").field(&StrongRc::debug_id(rc)).finish()
493    }
494}
495
496/// A weak reference to a raw IP socket.
497pub struct WeakRawIpSocketId<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes>(
498    WeakRc<RawIpSocketState<I, D, BT>>,
499);
500
501impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> Debug
502    for WeakRawIpSocketId<I, D, BT>
503{
504    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
505        let Self(rc) = self;
506        f.debug_tuple("WeakRawIpSocketId").field(&WeakRc::debug_id(rc)).finish()
507    }
508}
509
510/// An alias for [`RawIpSocketId`] in [`RawIpSocketApi`], for brevity.
511type RawIpApiSocketId<I, C> = RawIpSocketId<
512    I,
513    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
514    <C as ContextPair>::BindingsContext,
515>;
516
517/// Provides access to the [`RawIpSocketLockedState`] for a raw IP socket.
518///
519/// Implementations must ensure a proper lock ordering is adhered to.
520pub trait RawIpSocketStateContext<I: IpExt, BT: RawIpSocketsBindingsTypes>:
521    DeviceIdContext<AnyDevice>
522{
523    /// The implementation of `IpSocketHandler` available after having locked
524    /// the state for an individual socket.
525    type SocketHandler<'a>: IpSocketHandler<
526        I,
527        BT,
528        DeviceId = Self::DeviceId,
529        WeakDeviceId = Self::WeakDeviceId,
530    >;
531
532    /// Calls the callback with an immutable reference to the socket's locked
533    /// state.
534    fn with_locked_state<O, F: FnOnce(&RawIpSocketLockedState<I, Self::WeakDeviceId>) -> O>(
535        &mut self,
536        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
537        cb: F,
538    ) -> O;
539
540    /// Calls the callback with an immutable reference to the socket's locked
541    /// state and the `SocketHandler`.
542    fn with_locked_state_and_socket_handler<
543        O,
544        F: FnOnce(&RawIpSocketLockedState<I, Self::WeakDeviceId>, &mut Self::SocketHandler<'_>) -> O,
545    >(
546        &mut self,
547        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
548        cb: F,
549    ) -> O;
550
551    /// Calls the callback with a mutable reference to the socket's locked
552    /// state.
553    fn with_locked_state_mut<
554        O,
555        F: FnOnce(&mut RawIpSocketLockedState<I, Self::WeakDeviceId>) -> O,
556    >(
557        &mut self,
558        id: &RawIpSocketId<I, Self::WeakDeviceId, BT>,
559        cb: F,
560    ) -> O;
561}
562
563/// The collection of all raw IP sockets installed in the system.
564///
565/// Implementations must ensure a proper lock ordering is adhered to.
566#[derive(Derivative)]
567#[derivative(Default(bound = ""))]
568pub struct RawIpSocketMap<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> {
569    /// All sockets installed in the system.
570    ///
571    /// This is a nested collection, with the outer `BTreeMap` indexable by the
572    /// socket's protocol, which allows for more efficient delivery of received
573    /// IP packets.
574    ///
575    /// NB: The inner map is a `HashMap` keyed by strong IDs, rather than an
576    /// `HashSet` keyed by primary IDs, because it would be impossible to build
577    /// a lookup key for the hashset (there can only ever exist 1 primary ID,
578    /// which is *in* the set).
579    sockets: BTreeMap<
580        RawIpSocketProtocol<I>,
581        HashMap<RawIpSocketId<I, D, BT>, PrimaryRawIpSocketId<I, D, BT>>,
582    >,
583}
584
585impl<I: IpExt, D: WeakDeviceIdentifier, BT: RawIpSocketsBindingsTypes> RawIpSocketMap<I, D, BT> {
586    fn insert(&mut self, socket: PrimaryRawIpSocketId<I, D, BT>) -> RawIpSocketId<I, D, BT> {
587        let RawIpSocketMap { sockets } = self;
588        let PrimaryRawIpSocketId(primary) = &socket;
589        let strong = RawIpSocketId(PrimaryRc::clone_strong(primary));
590        // NB: The socket must be newly inserted because there can only ever
591        // be a single primary ID for a socket.
592        assert!(sockets
593            .entry(*strong.protocol())
594            .or_default()
595            .insert(strong.clone(), socket)
596            .is_none());
597        strong
598    }
599
600    fn remove(&mut self, socket: RawIpSocketId<I, D, BT>) -> PrimaryRawIpSocketId<I, D, BT> {
601        // NB: This function asserts on the presence of `protocol` in the
602        // outer map, and the `socket` in the inner map.  The strong ID is
603        // witness to the liveness of socket.
604        let RawIpSocketMap { sockets } = self;
605        let protocol = *socket.protocol();
606        match sockets.entry(protocol) {
607            Entry::Vacant(_) => unreachable!(
608                "{socket:?} with protocol {protocol:?} must be present in the socket map"
609            ),
610            Entry::Occupied(mut entry) => {
611                let map = entry.get_mut();
612                let primary = map.remove(&socket).unwrap();
613                // NB: If this was the last socket for this protocol, remove
614                // the entry from the outer `BTreeMap`.
615                if map.is_empty() {
616                    let _: HashMap<RawIpSocketId<I, D, BT>, PrimaryRawIpSocketId<I, D, BT>> =
617                        entry.remove();
618                }
619                primary
620            }
621        }
622    }
623
624    fn iter_sockets(&self) -> impl Iterator<Item = &RawIpSocketId<I, D, BT>> {
625        let RawIpSocketMap { sockets } = self;
626        sockets.values().flat_map(|sockets_by_protocol| sockets_by_protocol.keys())
627    }
628
629    fn iter_sockets_for_protocol(
630        &self,
631        protocol: &RawIpSocketProtocol<I>,
632    ) -> impl Iterator<Item = &RawIpSocketId<I, D, BT>> {
633        let RawIpSocketMap { sockets } = self;
634        sockets.get(protocol).map(|sockets| sockets.keys()).into_iter().flatten()
635    }
636}
637
638/// A type that provides access to the `RawIpSocketMap` used by the system.
639pub trait RawIpSocketMapContext<I: IpExt, BT: RawIpSocketsBindingsTypes>:
640    DeviceIdContext<AnyDevice>
641{
642    /// The implementation of `RawIpSocketStateContext` available after having
643    /// accessed the system's socket map.
644    type StateCtx<'a>: RawIpSocketStateContext<I, BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
645        + ResourceCounterContext<RawIpSocketId<I, Self::WeakDeviceId, BT>, RawIpSocketCounters<I>>;
646
647    /// Calls the callback with an immutable reference to the socket map.
648    fn with_socket_map_and_state_ctx<
649        O,
650        F: FnOnce(&RawIpSocketMap<I, Self::WeakDeviceId, BT>, &mut Self::StateCtx<'_>) -> O,
651    >(
652        &mut self,
653        cb: F,
654    ) -> O;
655    /// Calls the callback with a mutable reference to the socket map.
656    fn with_socket_map_mut<O, F: FnOnce(&mut RawIpSocketMap<I, Self::WeakDeviceId, BT>) -> O>(
657        &mut self,
658        cb: F,
659    ) -> O;
660}
661
662/// A type that provides the raw IP socket functionality required by core.
663pub trait RawIpSocketHandler<I: IpExt, BC>: DeviceIdContext<AnyDevice> {
664    /// Deliver a received IP packet to all appropriate raw IP sockets.
665    fn deliver_packet_to_raw_ip_sockets<B: SplitByteSlice>(
666        &mut self,
667        bindings_ctx: &mut BC,
668        packet: &I::Packet<B>,
669        device: &Self::DeviceId,
670    );
671}
672
673impl<I, BC, CC> RawIpSocketHandler<I, BC> for CC
674where
675    I: IpExt,
676    BC: RawIpSocketsBindingsContext<I, CC::DeviceId>,
677    CC: RawIpSocketMapContext<I, BC>,
678{
679    fn deliver_packet_to_raw_ip_sockets<B: SplitByteSlice>(
680        &mut self,
681        bindings_ctx: &mut BC,
682        packet: &I::Packet<B>,
683        device: &CC::DeviceId,
684    ) {
685        let protocol = RawIpSocketProtocol::new(packet.proto());
686
687        // NB: sockets with `RawIpSocketProtocol::Raw` are send only, and cannot
688        // receive packets.
689        match protocol {
690            RawIpSocketProtocol::Raw => {
691                debug!("received IP packet with raw protocol (IANA Reserved - 255); dropping");
692                return;
693            }
694            RawIpSocketProtocol::Proto(_) => {}
695        };
696
697        self.with_socket_map_and_state_ctx(|socket_map, core_ctx| {
698            socket_map.iter_sockets_for_protocol(&protocol).for_each(|socket| {
699                match core_ctx.with_locked_state(socket, |state| {
700                    check_packet_for_delivery(packet, device, state)
701                }) {
702                    DeliveryOutcome::Deliver => {
703                        core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
704                            &counters.rx_packets
705                        });
706                        bindings_ctx.receive_packet(socket, packet, device);
707                    }
708                    DeliveryOutcome::WrongChecksum => {
709                        core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
710                            &counters.rx_checksum_errors
711                        });
712                    }
713                    DeliveryOutcome::WrongIcmpMessageType => {
714                        core_ctx.increment_both(&socket, |counters: &RawIpSocketCounters<I>| {
715                            &counters.rx_icmp_filtered
716                        });
717                    }
718                    DeliveryOutcome::WrongDevice => {}
719                }
720            })
721        })
722    }
723}
724
/// The verdict on whether a received IP packet goes to a particular socket.
enum DeliveryOutcome {
    /// Hand the packet to the socket.
    Deliver,
    /// Withhold: the packet arrived on a device other than the socket's bound
    /// device.
    WrongDevice,
    /// Withhold: the packet's checksum is invalid.
    WrongChecksum,
    /// Withhold: the socket's ICMP filter rejects the packet's inner ICMP
    /// message type.
    WrongIcmpMessageType,
}
737
738/// Returns whether the given packet should be delivered to the given socket.
739fn check_packet_for_delivery<I: IpExt, D: StrongDeviceIdentifier, B: SplitByteSlice>(
740    packet: &I::Packet<B>,
741    device: &D,
742    socket: &RawIpSocketLockedState<I, D::Weak>,
743) -> DeliveryOutcome {
744    let RawIpSocketLockedState {
745        bound_device,
746        icmp_filter,
747        hop_limits: _,
748        marks: _,
749        multicast_loop: _,
750        system_checksums,
751    } = socket;
752    // Verify the received device matches the socket's bound device, if any.
753    if bound_device.as_ref().is_some_and(|bound_device| bound_device != device) {
754        return DeliveryOutcome::WrongDevice;
755    }
756
757    // Verify the inner message's checksum, if requested.
758    // NB: The checksum was not previously validated by the IP layer, because
759    // packets are delivered to raw sockets before the IP layer attempts to
760    // parse the inner message.
761    if *system_checksums && !checksum::has_valid_checksum::<I, B>(packet) {
762        return DeliveryOutcome::WrongChecksum;
763    }
764
765    // Verify the packet passes the socket's icmp_filter, if any.
766    if icmp_filter.as_ref().is_some_and(|icmp_filter| {
767        // NB: If the socket has an icmp_filter, its protocol must be ICMP.
768        // That means the packet must be ICMP, because we're considering
769        // delivering it to this socket.
770        debug_assert!(RawIpSocketProtocol::<I>::new(packet.proto()).is_icmp());
771        match icmp::peek_message_type(packet.body()) {
772            // NB: The peek call above will fail if 1) the body doesn't have
773            // enough bytes to be an ICMP header, or if 2) the message_type from
774            // the header is unrecognized. In either case, don't deliver the
775            // packet. This is consistent with Linux in the first case, but not
776            // the second (e.g. linux *will* deliver the packet if it has an
777            // invalid ICMP message type). This divergence is not expected to be
778            // problematic for clients, and as such it is kept for the improved
779            // type safety when operating on a known to be valid ICMP message
780            // type.
781            Err(_) => true,
782            Ok(message_type) => !icmp_filter.allows_type(message_type),
783        }
784    }) {
785        return DeliveryOutcome::WrongIcmpMessageType;
786    }
787
788    DeliveryOutcome::Deliver
789}
790
791/// An implementation of [`SendOptions`] for raw IP sockets.
792struct RawIpSocketOptions<'a, I: Ip> {
793    hop_limits: &'a SocketHopLimits<I>,
794    multicast_loop: bool,
795    marks: &'a Marks,
796}
797
798impl<I: Ip> RouteResolutionOptions<I> for RawIpSocketOptions<'_, I> {
799    fn transparent(&self) -> bool {
800        false
801    }
802
803    fn marks(&self) -> &Marks {
804        self.marks
805    }
806}
807
808impl<I: IpExt> SendOptions<I> for RawIpSocketOptions<'_, I> {
809    fn hop_limit(&self, destination: &SpecifiedAddr<I::Addr>) -> Option<NonZeroU8> {
810        self.hop_limits.hop_limit_for_dst(destination)
811    }
812
813    fn multicast_loop(&self) -> bool {
814        self.multicast_loop
815    }
816
817    fn allow_broadcast(&self) -> Option<I::BroadcastMarker> {
818        None
819    }
820
821    fn dscp_and_ecn(&self) -> DscpAndEcn {
822        DscpAndEcn::default()
823    }
824
825    fn mtu(&self) -> Mtu {
826        Mtu::no_limit()
827    }
828}
829
/// Marker type indicating that raw IP sockets have no ports; it displays as
/// `NoPort`.
struct NoPortMarker;

impl Display for NoPortMarker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NoPort")
    }
}
838
839#[cfg(test)]
840mod test {
841    use super::*;
842
843    use alloc::rc::Rc;
844    use alloc::vec;
845    use alloc::vec::Vec;
846    use assert_matches::assert_matches;
847    use core::cell::RefCell;
848    use core::convert::Infallible as Never;
849    use core::marker::PhantomData;
850    use ip_test_macro::ip_test;
851    use net_types::ip::{IpVersion, Ipv4, Ipv6};
852    use netstack3_base::sync::{DynDebugReferences, Mutex};
853    use netstack3_base::testutil::{
854        FakeStrongDeviceId, FakeTxMetadata, FakeWeakDeviceId, MultipleDevicesId, TestIpExt,
855    };
856    use netstack3_base::{ContextProvider, CounterContext, CtxPair};
857    use packet::{Buf, InnerPacketBuilder as _, ParseBuffer as _, Serializer as _};
858    use packet_formats::icmp::{
859        IcmpEchoReply, IcmpMessage, IcmpPacketBuilder, IcmpZeroCode, Icmpv6MessageType,
860    };
861    use packet_formats::ip::{IpPacketBuilder, IpProto, IpProtoExt, Ipv6Proto};
862    use packet_formats::ipv6::Ipv6Packet;
863    use test_case::test_case;
864
865    use crate::internal::socket::testutil::{FakeIpSocketCtx, InnerFakeIpSocketCtx};
866    use crate::socket::testutil::FakeDeviceConfig;
867    use crate::{SendIpPacketMeta, DEFAULT_HOP_LIMITS};
868
869    #[derive(Derivative, Debug)]
870    #[derivative(Default(bound = ""))]
871    struct FakeExternalSocketState<D> {
872        /// The collection of IP packets received on this socket.
873        received_packets: Mutex<Vec<ReceivedIpPacket<D>>>,
874    }
875
876    #[derive(Debug, PartialEq)]
877    struct ReceivedIpPacket<D> {
878        data: Vec<u8>,
879        device: D,
880    }
881
882    #[derive(Derivative)]
883    #[derivative(Default(bound = ""))]
884    struct FakeBindingsCtx<D> {
885        _device_id_type: PhantomData<D>,
886    }
887
888    /// State required to test raw IP sockets. Held by `FakeCoreCtx`.
889    struct FakeCoreCtxState<I: IpExt, D: FakeStrongDeviceId> {
890        // NB: Hold in an `Rc<RefCell<...>>` to switch to runtime borrow
891        // checking. This allows us to borrow the socket map at the same time
892        // as the outer `FakeCoreCtx` is mutably borrowed (Required to implement
893        // `RawIpSocketMapContext::with_socket_map_and_state_ctx`).
894        socket_map: Rc<RefCell<RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>>>,
895        /// An inner fake implementation of `IpSocketHandler`. By implementing
896        /// `InnerFakeIpSocketCtx` below, the `FakeCoreCtx` will be eligible for
897        /// a blanket impl of `IpSocketHandler`.
898        ip_socket_ctx: FakeIpSocketCtx<I, D>,
899        /// The aggregate counters for raw ip sockets.
900        counters: RawIpSocketCounters<I>,
901    }
902
903    impl<I: IpExt, D: FakeStrongDeviceId> InnerFakeIpSocketCtx<I, D> for FakeCoreCtxState<I, D> {
904        fn fake_ip_socket_ctx_mut(&mut self) -> &mut FakeIpSocketCtx<I, D> {
905            &mut self.ip_socket_ctx
906        }
907    }
908
909    type FakeCoreCtx<I, D> = netstack3_base::testutil::FakeCoreCtx<
910        FakeCoreCtxState<I, D>,
911        SendIpPacketMeta<I, D, SpecifiedAddr<<I as Ip>::Addr>>,
912        D,
913    >;
914
915    impl<D: FakeStrongDeviceId> TxMetadataBindingsTypes for FakeBindingsCtx<D> {
916        type TxMetadata = FakeTxMetadata;
917    }
918
919    impl<D: FakeStrongDeviceId> RawIpSocketsBindingsTypes for FakeBindingsCtx<D> {
920        type RawIpSocketState<I: Ip> = FakeExternalSocketState<D>;
921    }
922
923    impl<I: IpExt, D: Copy + FakeStrongDeviceId> RawIpSocketsBindingsContext<I, D>
924        for FakeBindingsCtx<D>
925    {
926        fn receive_packet<B: SplitByteSlice>(
927            &self,
928            socket: &RawIpSocketId<I, D::Weak, Self>,
929            packet: &I::Packet<B>,
930            device: &D,
931        ) {
932            let packet = ReceivedIpPacket { data: packet.to_vec(), device: *device };
933            let FakeExternalSocketState { received_packets } = socket.external_state();
934            received_packets.lock().push(packet);
935        }
936    }
937
938    impl<I: IpExt, D: FakeStrongDeviceId> RawIpSocketStateContext<I, FakeBindingsCtx<D>>
939        for FakeCoreCtx<I, D>
940    {
941        type SocketHandler<'a> = FakeCoreCtx<I, D>;
942        fn with_locked_state<O, F: FnOnce(&RawIpSocketLockedState<I, D::Weak>) -> O>(
943            &mut self,
944            id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
945            cb: F,
946        ) -> O {
947            let RawIpSocketId(state_rc) = id;
948            let guard = state_rc.locked_state().read();
949            cb(&guard)
950        }
951        fn with_locked_state_and_socket_handler<
952            O,
953            F: FnOnce(&RawIpSocketLockedState<I, D::Weak>, &mut Self::SocketHandler<'_>) -> O,
954        >(
955            &mut self,
956            id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
957            cb: F,
958        ) -> O {
959            let RawIpSocketId(state_rc) = id;
960            let guard = state_rc.locked_state().read();
961            cb(&guard, self)
962        }
963        fn with_locked_state_mut<O, F: FnOnce(&mut RawIpSocketLockedState<I, D::Weak>) -> O>(
964            &mut self,
965            id: &RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
966            cb: F,
967        ) -> O {
968            let RawIpSocketId(state_rc) = id;
969            let mut guard = state_rc.locked_state().write();
970            cb(&mut guard)
971        }
972    }
973
974    impl<I: IpExt, D: FakeStrongDeviceId> CounterContext<RawIpSocketCounters<I>> for FakeCoreCtx<I, D> {
975        fn counters(&self) -> &RawIpSocketCounters<I> {
976            &self.state.counters
977        }
978    }
979
980    impl<I: IpExt, D: FakeStrongDeviceId>
981        ResourceCounterContext<
982            RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
983            RawIpSocketCounters<I>,
984        > for FakeCoreCtx<I, D>
985    {
986        fn per_resource_counters<'a>(
987            &'a self,
988            socket: &'a RawIpSocketId<I, D::Weak, FakeBindingsCtx<D>>,
989        ) -> &'a RawIpSocketCounters<I> {
990            socket.state().counters()
991        }
992    }
993
994    impl<I: IpExt, D: FakeStrongDeviceId> RawIpSocketMapContext<I, FakeBindingsCtx<D>>
995        for FakeCoreCtx<I, D>
996    {
997        type StateCtx<'a> = FakeCoreCtx<I, D>;
998        fn with_socket_map_and_state_ctx<
999            O,
1000            F: FnOnce(&RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>, &mut Self::StateCtx<'_>) -> O,
1001        >(
1002            &mut self,
1003            cb: F,
1004        ) -> O {
1005            let socket_map = self.state.socket_map.clone();
1006            let borrow = socket_map.borrow();
1007            cb(&borrow, self)
1008        }
1009        fn with_socket_map_mut<
1010            O,
1011            F: FnOnce(&mut RawIpSocketMap<I, D::Weak, FakeBindingsCtx<D>>) -> O,
1012        >(
1013            &mut self,
1014            cb: F,
1015        ) -> O {
1016            cb(&mut self.state.socket_map.borrow_mut())
1017        }
1018    }
1019
1020    impl<D> ContextProvider for FakeBindingsCtx<D> {
1021        type Context = FakeBindingsCtx<D>;
1022        fn context(&mut self) -> &mut Self::Context {
1023            self
1024        }
1025    }
1026
1027    impl<D> ReferenceNotifiers for FakeBindingsCtx<D> {
1028        type ReferenceReceiver<T: 'static> = Never;
1029
1030        type ReferenceNotifier<T: Send + 'static> = Never;
1031
1032        fn new_reference_notifier<T: Send + 'static>(
1033            _debug_references: DynDebugReferences,
1034        ) -> (Self::ReferenceNotifier<T>, Self::ReferenceReceiver<T>) {
1035            unimplemented!("raw IP socket removal shouldn't be deferred in tests");
1036        }
1037    }
1038
1039    fn new_raw_ip_socket_api<I: IpExt + TestIpExt>() -> RawIpSocketApi<
1040        I,
1041        CtxPair<FakeCoreCtx<I, MultipleDevicesId>, FakeBindingsCtx<MultipleDevicesId>>,
1042    > {
1043        // Set up all devices with a local IP and a route to the remote IP.
1044        let device_configs = [MultipleDevicesId::A, MultipleDevicesId::B, MultipleDevicesId::C]
1045            .into_iter()
1046            .map(|device| FakeDeviceConfig {
1047                device,
1048                local_ips: vec![I::TEST_ADDRS.local_ip],
1049                remote_ips: vec![I::TEST_ADDRS.remote_ip],
1050            });
1051        let state = FakeCoreCtxState {
1052            socket_map: Default::default(),
1053            ip_socket_ctx: FakeIpSocketCtx::new(device_configs),
1054            counters: Default::default(),
1055        };
1056
1057        RawIpSocketApi::new(CtxPair::with_core_ctx(FakeCoreCtx::with_state(state)))
1058    }
1059
1060    /// Arbitrary data to put inside of an IP packet.
1061    const IP_BODY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
1062
1063    /// Constructs a buffer containing an IP packet with sensible defaults.
1064    fn new_ip_packet_buf<I: IpExt + TestIpExt>(
1065        ip_body: &[u8],
1066        proto: I::Proto,
1067    ) -> impl AsRef<[u8]> {
1068        const TTL: u8 = 255;
1069        ip_body
1070            .into_serializer()
1071            .encapsulate(I::PacketBuilder::new(
1072                *I::TEST_ADDRS.local_ip,
1073                *I::TEST_ADDRS.remote_ip,
1074                TTL,
1075                proto,
1076            ))
1077            .serialize_vec_outer()
1078            .unwrap()
1079    }
1080
1081    /// Construct a buffer containing an ICMP message with sensible defaults.
1082    fn new_icmp_message_buf<I: IpExt + TestIpExt, M: IcmpMessage<I> + Debug>(
1083        message: M,
1084        code: M::Code,
1085    ) -> impl AsRef<[u8]> {
1086        [].into_serializer()
1087            .encapsulate(IcmpPacketBuilder::new(
1088                *I::TEST_ADDRS.local_ip,
1089                *I::TEST_ADDRS.remote_ip,
1090                code,
1091                message,
1092            ))
1093            .serialize_vec_outer()
1094            .unwrap()
1095    }
1096
1097    #[ip_test(I)]
1098    #[test_case(IpProto::Udp; "UDP")]
1099    #[test_case(IpProto::Reserved; "IPPROTO_RAW")]
1100    fn create_and_close<I: IpExt + DualStackIpExt + TestIpExt>(proto: IpProto) {
1101        let mut api = new_raw_ip_socket_api::<I>();
1102        let sock = api.create(RawIpSocketProtocol::new(proto.into()), Default::default());
1103        let FakeExternalSocketState { received_packets: _ } = api.close(sock).into_removed();
1104    }
1105
1106    #[ip_test(I)]
1107    fn set_device<I: IpExt + DualStackIpExt + TestIpExt>() {
1108        let mut api = new_raw_ip_socket_api::<I>();
1109        let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1110
1111        assert_eq!(api.get_device(&sock), None);
1112        assert_eq!(api.set_device(&sock, Some(&MultipleDevicesId::A)), None);
1113        assert_eq!(api.get_device(&sock), Some(FakeWeakDeviceId(MultipleDevicesId::A)));
1114        assert_eq!(
1115            api.set_device(&sock, Some(&MultipleDevicesId::B)),
1116            Some(FakeWeakDeviceId(MultipleDevicesId::A))
1117        );
1118        assert_eq!(api.get_device(&sock), Some(FakeWeakDeviceId(MultipleDevicesId::B)));
1119        assert_eq!(api.set_device(&sock, None), Some(FakeWeakDeviceId(MultipleDevicesId::B)));
1120        assert_eq!(api.get_device(&sock), None);
1121    }
1122
1123    #[ip_test(I)]
1124    fn set_icmp_filter<I: IpExt + DualStackIpExt + TestIpExt>() {
1125        let filter1 = RawIpSocketIcmpFilter::<I>::new([123; 32]);
1126        let filter2 = RawIpSocketIcmpFilter::<I>::new([234; 32]);
1127        let mut api = new_raw_ip_socket_api::<I>();
1128
1129        let sock = api.create(RawIpSocketProtocol::new(I::ICMP_IP_PROTO), Default::default());
1130        assert_eq!(api.get_icmp_filter(&sock), Ok(None));
1131        assert_eq!(api.set_icmp_filter(&sock, Some(filter1.clone())), Ok(None));
1132        assert_eq!(api.get_icmp_filter(&sock), Ok(Some(filter1.clone())));
1133        assert_eq!(api.set_icmp_filter(&sock, Some(filter2.clone())), Ok(Some(filter1.clone())));
1134        assert_eq!(api.get_icmp_filter(&sock), Ok(Some(filter2.clone())));
1135        assert_eq!(api.set_icmp_filter(&sock, None), Ok(Some(filter2)));
1136        assert_eq!(api.get_icmp_filter(&sock), Ok(None));
1137
1138        // Sockets created with a non ICMP protocol cannot set an ICMP filter.
1139        let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1140        assert_eq!(
1141            api.set_icmp_filter(&sock, Some(filter1)),
1142            Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp)
1143        );
1144        assert_eq!(api.get_icmp_filter(&sock), Err(RawIpSocketIcmpFilterError::ProtocolNotIcmp));
1145    }
1146
1147    #[ip_test(I)]
1148    fn set_unicast_hop_limits<I: IpExt + DualStackIpExt + TestIpExt>() {
1149        let mut api = new_raw_ip_socket_api::<I>();
1150        let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1151
1152        let limit1 = NonZeroU8::new(1).unwrap();
1153        let limit2 = NonZeroU8::new(2).unwrap();
1154
1155        assert_eq!(api.get_unicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.unicast);
1156        assert_eq!(api.set_unicast_hop_limit(&sock, Some(limit1)), None);
1157        assert_eq!(api.get_unicast_hop_limit(&sock), limit1);
1158        assert_eq!(api.set_unicast_hop_limit(&sock, Some(limit2)), Some(limit1));
1159        assert_eq!(api.get_unicast_hop_limit(&sock), limit2);
1160        assert_eq!(api.set_unicast_hop_limit(&sock, None), Some(limit2));
1161        assert_eq!(api.get_unicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.unicast);
1162    }
1163
1164    #[ip_test(I)]
1165    fn set_multicast_hop_limit<I: IpExt + DualStackIpExt + TestIpExt>() {
1166        let mut api = new_raw_ip_socket_api::<I>();
1167        let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1168
1169        let limit1 = NonZeroU8::new(1).unwrap();
1170        let limit2 = NonZeroU8::new(2).unwrap();
1171
1172        assert_eq!(api.get_multicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.multicast);
1173        assert_eq!(api.set_multicast_hop_limit(&sock, Some(limit1)), None);
1174        assert_eq!(api.get_multicast_hop_limit(&sock), limit1);
1175        assert_eq!(api.set_multicast_hop_limit(&sock, Some(limit2)), Some(limit1));
1176        assert_eq!(api.get_multicast_hop_limit(&sock), limit2);
1177        assert_eq!(api.set_multicast_hop_limit(&sock, None), Some(limit2));
1178        assert_eq!(api.get_multicast_hop_limit(&sock), DEFAULT_HOP_LIMITS.multicast);
1179    }
1180
1181    #[ip_test(I)]
1182    fn set_multicast_loop<I: IpExt + DualStackIpExt + TestIpExt>() {
1183        let mut api = new_raw_ip_socket_api::<I>();
1184        let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1185
1186        // NB: multicast loopback is enabled by default.
1187        assert_eq!(api.get_multicast_loop(&sock), true);
1188        assert_eq!(api.set_multicast_loop(&sock, false), true);
1189        assert_eq!(api.get_multicast_loop(&sock), false);
1190        assert_eq!(api.set_multicast_loop(&sock, true), false);
1191        assert_eq!(api.get_multicast_loop(&sock), true);
1192    }
1193
1194    #[ip_test(I)]
1195    fn receive_ip_packet<I: IpExt + DualStackIpExt + TestIpExt>() {
1196        let mut api = new_raw_ip_socket_api::<I>();
1197
1198        // Create two sockets with the right protocol, and one socket with the
1199        // wrong protocol.
1200        let proto: I::Proto = IpProto::Udp.into();
1201        let wrong_proto: I::Proto = IpProto::Tcp.into();
1202        let sock1 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1203        let sock2 = api.create(RawIpSocketProtocol::new(proto), Default::default());
1204        let wrong_sock = api.create(RawIpSocketProtocol::new(wrong_proto), Default::default());
1205
1206        // Receive an IP packet with protocol `proto`.
1207        const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1208        let buf = new_ip_packet_buf::<I>(&IP_BODY, proto);
1209        let mut buf_ref = buf.as_ref();
1210        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1211        {
1212            let (core_ctx, bindings_ctx) = api.ctx.contexts();
1213            core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &DEVICE);
1214        }
1215
1216        // Verify the counters were updated.
1217        assert_eq!(api.core_ctx().state.counters.rx_packets.get(), 2);
1218        assert_eq!(sock1.state().counters().rx_packets.get(), 1);
1219        assert_eq!(sock2.state().counters().rx_packets.get(), 1);
1220        assert_eq!(wrong_sock.state().counters().rx_packets.get(), 0);
1221
1222        let FakeExternalSocketState { received_packets: sock1_packets } =
1223            api.close(sock1).into_removed();
1224        let FakeExternalSocketState { received_packets: sock2_packets } =
1225            api.close(sock2).into_removed();
1226        let FakeExternalSocketState { received_packets: wrong_sock_packets } =
1227            api.close(wrong_sock).into_removed();
1228
1229        // Expect delivery to the two right sockets, but not the wrong socket.
1230        for packets in [sock1_packets, sock2_packets] {
1231            let lock_guard = packets.lock();
1232            let ReceivedIpPacket { data, device } =
1233                assert_matches!(&lock_guard[..], [packet] => packet);
1234            assert_eq!(&data[..], buf.as_ref());
1235            assert_eq!(*device, DEVICE);
1236        }
1237        assert_matches!(&wrong_sock_packets.lock()[..], []);
1238    }
1239
1240    // Verify that sockets created with `RawIpSocketProtocol::Raw` cannot
1241    // receive packets
1242    #[ip_test(I)]
1243    fn cannot_receive_ip_packet_with_proto_raw<I: IpExt + DualStackIpExt + TestIpExt>() {
1244        let mut api = new_raw_ip_socket_api::<I>();
1245        let sock = api.create(RawIpSocketProtocol::Raw, Default::default());
1246
1247        // Try to deliver to an arbitrary proto (UDP), and to the reserved
1248        // proto; neither should be delivered to the socket.
1249        let protocols_to_test = match I::VERSION {
1250            IpVersion::V4 => vec![IpProto::Udp, IpProto::Reserved],
1251            // NB: Don't test `Reserved` with IPv6; the packet will fail to
1252            // parse.
1253            IpVersion::V6 => vec![IpProto::Udp],
1254        };
1255        for proto in protocols_to_test {
1256            let buf = new_ip_packet_buf::<I>(&IP_BODY, proto.into());
1257            let mut buf_ref = buf.as_ref();
1258            let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1259            let (core_ctx, bindings_ctx) = api.ctx.contexts();
1260            core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1261        }
1262
1263        let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1264        assert_matches!(&received_packets.lock()[..], []);
1265    }
1266
1267    #[ip_test(I)]
1268    #[test_case(MultipleDevicesId::A, None, true; "no_bound_device")]
1269    #[test_case(MultipleDevicesId::A, Some(MultipleDevicesId::A), true; "bound_same_device")]
1270    #[test_case(MultipleDevicesId::A, Some(MultipleDevicesId::B), false; "bound_diff_device")]
1271    fn receive_ip_packet_with_bound_device<I: IpExt + DualStackIpExt + TestIpExt>(
1272        send_dev: MultipleDevicesId,
1273        bound_dev: Option<MultipleDevicesId>,
1274        should_deliver: bool,
1275    ) {
1276        const PROTO: IpProto = IpProto::Udp;
1277        let mut api = new_raw_ip_socket_api::<I>();
1278        let sock = api.create(RawIpSocketProtocol::new(PROTO.into()), Default::default());
1279
1280        assert_eq!(api.set_device(&sock, bound_dev.as_ref()), None);
1281
1282        // Deliver an arbitrary packet on `send_dev`.
1283        let buf = new_ip_packet_buf::<I>(&IP_BODY, PROTO.into());
1284        let mut buf_ref = buf.as_ref();
1285        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1286        {
1287            let (core_ctx, bindings_ctx) = api.ctx.contexts();
1288            core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &send_dev);
1289        }
1290
1291        // Verify the packet was/wasn't received, as expected.
1292        let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1293        if should_deliver {
1294            let lock_guard = received_packets.lock();
1295            let ReceivedIpPacket { data, device } =
1296                assert_matches!(&lock_guard[..], [packet] => packet);
1297            assert_eq!(&data[..], buf.as_ref());
1298            assert_eq!(*device, send_dev);
1299        } else {
1300            assert_matches!(&received_packets.lock()[..], []);
1301        }
1302    }
1303
1304    #[ip_test(I)]
1305    // NB: Don't bother testing for individual ICMP codes. The `filter` sub
1306    // module already covers that extensively.
1307    #[test_case(None, true; "no_filter")]
1308    #[test_case(Some(RawIpSocketIcmpFilter::<I>::ALLOW_ALL), true; "allow_all")]
1309    #[test_case(Some(RawIpSocketIcmpFilter::<I>::DENY_ALL), false; "deny_all")]
1310    fn receive_ip_packet_with_icmp_filter<I: IpExt + DualStackIpExt + TestIpExt>(
1311        filter: Option<RawIpSocketIcmpFilter<I>>,
1312        should_deliver: bool,
1313    ) {
1314        let mut api = new_raw_ip_socket_api::<I>();
1315        let sock = api.create(RawIpSocketProtocol::new(I::ICMP_IP_PROTO), Default::default());
1316
1317        let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1318            assert_eq!(core_ctx.state.counters.rx_icmp_filtered.get(), count);
1319            assert_eq!(sock.state().counters().rx_icmp_filtered.get(), count);
1320        };
1321        assert_counters(api.core_ctx(), 0);
1322
1323        assert_matches!(api.set_icmp_filter(&sock, filter), Ok(None));
1324
1325        // Deliver an arbitrary ICMP message.
1326        let icmp_body = new_icmp_message_buf::<I, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode);
1327        let buf = new_ip_packet_buf::<I>(icmp_body.as_ref(), I::ICMP_IP_PROTO);
1328        let mut buf_ref = buf.as_ref();
1329        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
1330        {
1331            let (core_ctx, bindings_ctx) = api.ctx.contexts();
1332            core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1333        }
1334
1335        // Verify the packet was/wasn't received, as expected.
1336        assert_counters(api.core_ctx(), should_deliver.then_some(0).unwrap_or(1));
1337        let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1338        if should_deliver {
1339            let lock_guard = received_packets.lock();
1340            let ReceivedIpPacket { data, device: _ } =
1341                assert_matches!(&lock_guard[..], [packet] => packet);
1342            assert_eq!(&data[..], buf.as_ref());
1343        } else {
1344            assert_matches!(&received_packets.lock()[..], []);
1345        }
1346    }
1347
1348    // Verify that ICMPv6 messages with an invalid checksum won't be received.
1349    // Note that the successful delivery case is tested by
1350    // `receive_ip_packet_with_icmp_filter`.
1351    #[test]
1352    fn do_not_receive_icmpv6_packet_with_bad_checksum() {
1353        let mut api = new_raw_ip_socket_api::<Ipv6>();
1354        let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1355
1356        let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1357            assert_eq!(core_ctx.state.counters.rx_checksum_errors.get(), count);
1358            assert_eq!(sock.state().counters().rx_checksum_errors.get(), count);
1359        };
1360        assert_counters(api.core_ctx(), 0);
1361
1362        // Use a valid ICMP message, but intentionally corrupt the checksum.
1363        // The checksum is present at bytes 2 & 3.
1364        let mut icmp_body = new_icmp_message_buf::<Ipv6, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode)
1365            .as_ref()
1366            .to_vec();
1367        const CORRUPT_CHECKSUM: [u8; 2] = [123, 234];
1368        assert_ne!(
1369            packet_formats::testutil::overwrite_icmpv6_checksum(
1370                icmp_body.as_mut(),
1371                CORRUPT_CHECKSUM
1372            )
1373            .expect("parse should succeed"),
1374            CORRUPT_CHECKSUM
1375        );
1376
1377        let buf = new_ip_packet_buf::<Ipv6>(icmp_body.as_ref(), Ipv6Proto::Icmpv6);
1378        let mut buf_ref = buf.as_ref();
1379        let packet = buf_ref.parse::<Ipv6Packet<_>>().expect("parse should succeed");
1380        {
1381            let (core_ctx, bindings_ctx) = api.ctx.contexts();
1382            core_ctx.deliver_packet_to_raw_ip_sockets(bindings_ctx, &packet, &MultipleDevicesId::A);
1383        }
1384
1385        // Verify the packet wasn't received.
1386        assert_counters(api.core_ctx(), 1);
1387        let FakeExternalSocketState { received_packets } = api.close(sock).into_removed();
1388        assert_matches!(&received_packets.lock()[..], []);
1389    }
1390
1391    #[ip_test(I)]
1392    #[test_case(None, None; "default_send")]
1393    #[test_case(Some(MultipleDevicesId::A), None; "with_bound_dev")]
1394    #[test_case(None, Some(123); "with_hop_limit")]
1395    fn send_to<I: IpExt + DualStackIpExt + TestIpExt>(
1396        bound_dev: Option<MultipleDevicesId>,
1397        hop_limit: Option<u8>,
1398    ) {
1399        const PROTO: IpProto = IpProto::Udp;
1400        let mut api = new_raw_ip_socket_api::<I>();
1401        let sock = api.create(RawIpSocketProtocol::new(PROTO.into()), Default::default());
1402
1403        let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1404            assert_eq!(core_ctx.state.counters.tx_packets.get(), count);
1405            assert_eq!(sock.state().counters().tx_packets.get(), count);
1406        };
1407        assert_counters(api.core_ctx(), 0);
1408
1409        assert_eq!(api.set_device(&sock, bound_dev.as_ref()), None);
1410        let hop_limit = hop_limit.and_then(NonZeroU8::new);
1411        assert_eq!(api.set_unicast_hop_limit(&sock, hop_limit), None);
1412
1413        let remote_ip = ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip);
1414        assert_matches!(&api.ctx.core_ctx().take_frames()[..], []);
1415        api.send_to(&sock, Some(remote_ip), Buf::new(IP_BODY.to_vec(), ..))
1416            .expect("send should succeed");
1417        let frames = api.core_ctx().take_frames();
1418        let (SendIpPacketMeta { device, src_ip, dst_ip, proto, mtu, ttl, .. }, data) =
1419            assert_matches!( &frames[..], [packet] => packet);
1420        assert_eq!(&data[..], &IP_BODY[..]);
1421        assert_eq!(*dst_ip, remote_ip.addr());
1422        assert_eq!(*src_ip, I::TEST_ADDRS.local_ip);
1423        if let Some(bound_dev) = bound_dev {
1424            assert_eq!(*device, bound_dev);
1425        }
1426        assert_eq!(*proto, <I as IpProtoExt>::Proto::from(PROTO));
1427        assert_eq!(*mtu, Mtu::max());
1428        assert_eq!(*ttl, hop_limit);
1429
1430        assert_counters(api.core_ctx(), 1);
1431    }
1432
1433    #[ip_test(I)]
1434    fn send_to_disallows_raw_protocol<I: IpExt + DualStackIpExt + TestIpExt>() {
1435        let mut api = new_raw_ip_socket_api::<I>();
1436        let sock = api.create(RawIpSocketProtocol::Raw, Default::default());
1437        assert_matches!(
1438            api.send_to(&sock, None, Buf::new(IP_BODY.to_vec(), ..)),
1439            Err(RawIpSocketSendToError::ProtocolRaw)
1440        );
1441    }
1442
1443    #[test]
1444    fn send_to_disallows_dualstack() {
1445        let mut api = new_raw_ip_socket_api::<Ipv6>();
1446        let sock = api.create(RawIpSocketProtocol::new(IpProto::Udp.into()), Default::default());
1447        let mapped_remote_ip = ZonedAddr::Unzoned(Ipv4::TEST_ADDRS.local_ip.to_ipv6_mapped());
1448        assert_matches!(
1449            api.send_to(&sock, Some(mapped_remote_ip), Buf::new(IP_BODY.to_vec(), ..)),
1450            Err(RawIpSocketSendToError::MappedRemoteIp)
1451        );
1452    }
1453
1454    // Verify that packets sent on ICMPv6 raw IP sockets have their checksum
1455    // automatically populated by the netstack.
1456    #[test]
1457    fn icmpv6_send_to_generates_checksum() {
1458        let mut api = new_raw_ip_socket_api::<Ipv6>();
1459        let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1460
1461        // Use a valid ICMP body, but intentionally corrupt the checksum.
1462        // The checksum is present at bytes 2 & 3.
1463        let icmp_body_with_checksum =
1464            new_icmp_message_buf::<Ipv6, _>(IcmpEchoReply::new(0, 0), IcmpZeroCode)
1465                .as_ref()
1466                .to_vec();
1467        const CORRUPT_CHECKSUM: [u8; 2] = [123, 234];
1468        let mut icmp_body_without_checksum = icmp_body_with_checksum.clone();
1469        assert_ne!(
1470            packet_formats::testutil::overwrite_icmpv6_checksum(
1471                icmp_body_without_checksum.as_mut(),
1472                CORRUPT_CHECKSUM,
1473            )
1474            .expect("parse should succeed"),
1475            CORRUPT_CHECKSUM
1476        );
1477
1478        // Send the buffer that has the corrupt checksum.
1479        let remote_ip = ZonedAddr::Unzoned(Ipv6::TEST_ADDRS.remote_ip);
1480        assert_matches!(&api.ctx.core_ctx().take_frames()[..], []);
1481        api.send_to(&sock, Some(remote_ip), Buf::new(icmp_body_without_checksum.to_vec(), ..))
1482            .expect("send should succeed");
1483
1484        // Observe that the checksum is populated.
1485        let frames = api.core_ctx().take_frames();
1486        let (_send_ip_packet_meta, data) = assert_matches!( &frames[..], [packet] => packet);
1487        assert_eq!(&data[..], icmp_body_with_checksum);
1488    }
1489
1490    // Verify that invalid ICMPv6 packets cannot be sent on raw IP sockets.
1491    #[test_case(Icmpv6MessageType::DestUnreachable.into(), 4; "header-too-short")]
1492    #[test_case(0, 8; "message-type-zero-not-supported")]
1493    fn icmpv6_send_to_invalid_body(message_type: u8, header_len: usize) {
1494        let mut api = new_raw_ip_socket_api::<Ipv6>();
1495        let sock = api.create(RawIpSocketProtocol::new(Ipv6Proto::Icmpv6), Default::default());
1496
1497        let assert_counters = |core_ctx: &mut FakeCoreCtx<_, _>, count: u64| {
1498            assert_eq!(core_ctx.state.counters.tx_checksum_errors.get(), count);
1499            assert_eq!(sock.state().counters().tx_checksum_errors.get(), count);
1500        };
1501
1502        let mut body = vec![0; header_len];
1503        body[0] = message_type;
1504        assert_counters(api.core_ctx(), 0);
1505
1506        let remote_ip = ZonedAddr::Unzoned(Ipv6::TEST_ADDRS.remote_ip);
1507        assert_matches!(
1508            api.send_to(&sock, Some(remote_ip), Buf::new(body, ..)),
1509            Err(RawIpSocketSendToError::InvalidBody)
1510        );
1511
1512        assert_counters(api.core_ctx(), 1);
1513    }
1514
1515    #[ip_test(I)]
1516    #[test_case::test_matrix(
1517        [MarkDomain::Mark1, MarkDomain::Mark2],
1518        [None, Some(0), Some(1)]
1519    )]
1520    fn raw_ip_socket_marks<I: TestIpExt + DualStackIpExt + IpExt>(
1521        domain: MarkDomain,
1522        mark: Option<u32>,
1523    ) {
1524        let mut api = new_raw_ip_socket_api::<I>();
1525        let socket = api.create(RawIpSocketProtocol::Raw, Default::default());
1526
1527        // Doesn't have a mark by default.
1528        assert_eq!(api.get_mark(&socket, domain), Mark(None));
1529
1530        let mark = Mark(mark);
1531        // We can set and get back the mark.
1532        api.set_mark(&socket, domain, mark);
1533        assert_eq!(api.get_mark(&socket, domain), mark);
1534    }
1535}