Skip to main content

netstack3_base/socket/
base.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! General-purpose socket utilities common to device layer and IP layer
6//! sockets.
7
8use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersion, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17    AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19use thiserror::Error;
20
21use crate::LocalAddressError;
22use crate::data_structures::socketmap::{
23    Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
24};
25use crate::device::{
26    DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
27};
28use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
29use crate::ip::BroadcastIpExt;
30use crate::socket::SocketCookie;
31use crate::socket::address::{
32    AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
33};
34use packet_formats::ip::{IpProto, Ipv4Proto, Ipv6Proto};
35
36/// A dual stack IP extention trait that provides the `OtherVersion` associated
37/// type.
38pub trait DualStackIpExt: Ip {
39    /// The "other" IP version, e.g. [`Ipv4`] for [`Ipv6`] and vice-versa.
40    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
41}
42
43impl DualStackIpExt for Ipv4 {
44    type OtherVersion = Ipv6;
45}
46
47impl DualStackIpExt for Ipv6 {
48    type OtherVersion = Ipv4;
49}
50
/// A tuple of values for `T` for both `I` and `I::OtherVersion`.
///
/// Each value is stored as `T`'s [`GenericOverIp`] projection onto the
/// respective IP version.
pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
    // `T` projected onto `I` (the "current" stack).
    this_stack: <T as GenericOverIp<I>>::Type,
    // `T` projected onto `I::OtherVersion`.
    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
    // Ties the tuple to the IP version `I` without storing a value of it.
    _marker: IpVersionMarker<I>,
}
57
58impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
59    /// Creates a new tuple with `this_stack` and `other_stack` values.
60    pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
61    where
62        T: GenericOverIp<I, Type = T>,
63    {
64        Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
65    }
66
67    /// Retrieves `(this_stack, other_stack)` from the tuple.
68    pub fn into_inner(
69        self,
70    ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
71        let Self { this_stack, other_stack, _marker } = self;
72        (this_stack, other_stack)
73    }
74
75    /// Retrieves `this_stack` from the tuple.
76    pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
77        self.this_stack
78    }
79
80    /// Borrows `this_stack` from the tuple.
81    pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
82        &self.this_stack
83    }
84
85    /// Retrieves `other_stack` from the tuple.
86    pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
87        self.other_stack
88    }
89
90    /// Borrows `other_stack` from the tuple.
91    pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
92        &self.other_stack
93    }
94
95    /// Flips the types, making `this_stack` `other_stack` and vice-versa.
96    pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
97        let Self { this_stack, other_stack, _marker } = self;
98        DualStackTuple {
99            this_stack: other_stack,
100            other_stack: this_stack,
101            _marker: IpVersionMarker::new(),
102        }
103    }
104
105    /// Casts to IP version `X`.
106    ///
107    /// Given `DualStackTuple` contains complete information for both IP
108    /// versions, it can be easily cast into an arbitrary `X` IP version.
109    ///
110    /// This can be used to tie together type parameters when dealing with dual
111    /// stack sockets. For example, a `DualStackTuple` defined for `SockI` can
112    /// be cast to any `WireI`.
113    pub fn cast<X>(self) -> DualStackTuple<X, T>
114    where
115        X: DualStackIpExt,
116        T: GenericOverIp<X>
117            + GenericOverIp<X::OtherVersion>
118            + GenericOverIp<Ipv4>
119            + GenericOverIp<Ipv6>,
120    {
121        I::map_ip_in(
122            self,
123            |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
124            |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
125        )
126    }
127}
128
// Re-targets a `DualStackTuple` at IP version `NewIp` while keeping the same
// payload type `T`; `T` must be projectable onto every IP version involved.
impl<
    I: DualStackIpExt,
    NewIp: DualStackIpExt,
    T: GenericOverIp<NewIp>
        + GenericOverIp<NewIp::OtherVersion>
        + GenericOverIp<I>
        + GenericOverIp<I::OtherVersion>,
> GenericOverIp<NewIp> for DualStackTuple<I, T>
{
    type Type = DualStackTuple<NewIp, T>;
}
140
/// Extension trait for `Ip` providing socket-specific functionality.
pub trait SocketIpExt: Ip {
    /// `Self::LOOPBACK_ADDRESS`, but wrapped in the `SocketIpAddr` type.
    ///
    /// Built with the unchecked constructor so that it can be evaluated in a
    /// `const` context.
    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
        // SAFETY: The loopback address is a valid SocketIpAddr, as verified
        // in the `loopback_addr_is_valid_socket_addr` test.
        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
    };
}

// Blanket impl: every `Ip` version gets the provided items of `SocketIpExt`.
impl<I: Ip> SocketIpExt for I {}
152
// Tests backing the `SAFETY` argument in `SocketIpExt`.
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        // `LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR` is defined with the "unchecked"
        // constructor (which supports const construction). Verify here, via
        // the checked `SocketIpAddr::new`, that the addr actually satisfies all
        // the requirements (protecting against far away changes).
        let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
            .expect("loopback address should be a valid SocketIpAddr");
    }
}
168
/// An IP version-specific protocol.
///
/// The variant records which IP version the protocol number belongs to.
#[derive(Copy, Clone, Debug, PartialEq, Eq, GenericOverIp)]
#[generic_over_ip()]
pub enum EitherIpProto {
    /// An IPv4 protocol.
    V4(Ipv4Proto),
    /// An IPv6 protocol.
    V6(Ipv6Proto),
}
178
179impl EitherIpProto {
180    /// Returns the IP version of the protocol.
181    pub fn ip_version(&self) -> IpVersion {
182        match self {
183            Self::V4(_) => IpVersion::V4,
184            Self::V6(_) => IpVersion::V6,
185        }
186    }
187
188    /// Returns the transport protocol if it is standard.
189    pub fn ip_proto(&self) -> Option<IpProto> {
190        match self {
191            Self::V4(p) => match p {
192                Ipv4Proto::Proto(proto) => Some(*proto),
193                _ => None,
194            },
195            Self::V6(p) => match p {
196                Ipv6Proto::Proto(proto) => Some(*proto),
197                _ => None,
198            },
199        }
200    }
201
202    /// Returns the raw protocol number as a u8.
203    pub fn u8_value(&self) -> u8 {
204        match self {
205            Self::V4(p) => (*p).into(),
206            Self::V6(p) => (*p).into(),
207        }
208    }
209}
210
/// Information about a socket passed to a socket operations filter.
#[derive(Clone, Debug)]
// Equality is only derived for tests and the `testutils` feature.
#[cfg_attr(any(test, feature = "testutils"), derive(PartialEq, Eq))]
pub struct SocketInfo {
    /// The IP-version-specific transport protocol.
    pub proto: EitherIpProto,
    /// The socket cookie.
    pub cookie: SocketCookie,
}
220
/// State belonging to either IP stack.
///
/// Like `Either` from the `either` crate, but with more helpful variant names.
///
/// Note that this type is not optimally type-safe, because `T` and `O` are not
/// bound to `I` and `I::OtherVersion` (for some `I: Ip`), respectively. In many
/// cases it may be more appropriate to define a one-off enum parameterized over
/// `I: Ip`.
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// In the current stack version.
    ThisStack(T),
    /// In the other version of the stack.
    OtherStack(O),
}

// Implemented by hand rather than derived. The visible difference from a
// derive is the `track_caller` attribute under the `instrumented` feature.
impl<T, O> Clone for EitherStack<T, O>
where
    T: Clone,
    O: Clone,
{
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        match self {
            Self::ThisStack(this) => Self::ThisStack(this.clone()),
            Self::OtherStack(other) => Self::OtherStack(other.clone()),
        }
    }
}
249
/// Control flow type containing either a dual-stack or non-dual-stack context.
///
/// This type exists to provide nice names to the result of
/// [`BoundStateContext::dual_stack_context`], and to allow generic code to
/// match on when checking whether a socket protocol and IP version support
/// dual-stack operation. If dual-stack operation is supported, a
/// [`MaybeDualStack::DualStack`] value will be held, otherwise a
/// [`MaybeDualStack::NotDualStack`] value.
///
/// Note that the type parameters do not have trait bounds; those are provided
/// by the trait with the `dual_stack_context` function.
///
/// In monomorphized code, this type frequently has exactly one type parameter
/// that is uninstantiable (it contains an instance of
/// [`core::convert::Infallible`] or some other empty enum, or a reference to
/// the same)! That lets the compiler optimize it out completely, creating no
/// actual runtime overhead.
#[derive(Debug)]
pub enum MaybeDualStack<DS, NDS> {
    /// The socket protocol and IP version support dual-stack operation.
    DualStack(DS),
    /// The socket protocol and IP version do not support dual-stack operation.
    NotDualStack(NDS),
}
273
// Implement `GenericOverIp` for a `MaybeDualStack` whose `DS` and `NDS` also
// implement `GenericOverIp`. Each payload type is projected onto `I`
// independently; the variant structure itself is unchanged.
impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
    for MaybeDualStack<DS, NDS>
{
    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
}
281
/// An error encountered while enabling or disabling dual-stack operation.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
    /// A socket can only have dual stack enabled or disabled while unbound.
    #[error("a socket can only have dual stack enabled or disabled while unbound")]
    SocketIsBound,
    /// The socket's protocol is not dual stack capable.
    ///
    /// `#[from]` provides a conversion from [`NotDualStackCapableError`], so
    /// `?` can lift that error into this one.
    #[error(transparent)]
    NotCapable(#[from] NotDualStackCapableError),
}

/// An error encountered when attempting to perform dual stack operations on
/// a socket whose protocol is not dual-stack capable.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
#[error("socket's protocol is not dual-stack capable")]
pub struct NotDualStackCapableError;
300
/// Describes which direction(s) of the data path should be shut down.
///
/// The `Default` value has neither direction shut down (both fields `false`).
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// True if the send path is shut down for the owning socket.
    ///
    /// If this is true, the socket should not be able to send packets.
    pub send: bool,
    /// True if the receive path is shut down for the owning socket.
    ///
    /// If this is true, the socket should not be able to receive packets.
    pub receive: bool,
}
313
/// Which direction(s) to shut down for a socket.
///
/// Unlike [`Shutdown`], this cannot express "nothing shut down".
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
    /// Prevent sending packets on the socket.
    Send,
    /// Prevent receiving packets on the socket.
    Receive,
    /// Prevent sending and receiving packets on the socket.
    SendAndReceive,
}
325
326impl ShutdownType {
327    /// Returns a tuple of booleans for `(shutdown_send, shutdown_receive)`.
328    pub fn to_send_receive(&self) -> (bool, bool) {
329        match self {
330            Self::Send => (true, false),
331            Self::Receive => (false, true),
332            Self::SendAndReceive => (true, true),
333        }
334    }
335
336    /// Creates a [`ShutdownType`] from a pair of bools for send and receive.
337    pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
338        match (send, receive) {
339            (true, false) => Some(Self::Send),
340            (false, true) => Some(Self::Receive),
341            (true, true) => Some(Self::SendAndReceive),
342            (false, false) => None,
343        }
344    }
345}
346
347/// Extensions to IP Address witnesses useful in the context of sockets.
348pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
349    /// Determines whether the provided address is underspecified by itself.
350    ///
351    /// Some addresses are ambiguous and so must have a zone identifier in order
352    /// to be used in a socket address. This function returns true for IPv6
353    /// link-local addresses and false for all others.
354    fn must_have_zone(&self) -> bool
355    where
356        Self: Copy,
357    {
358        self.try_into_null_zoned().is_some()
359    }
360
361    /// Converts into a [`AddrAndZone<A, ()>`] if the address requires a zone.
362    ///
363    /// Otherwise returns `None`.
364    fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
365        if self.get().is_loopback() {
366            return None;
367        }
368        AddrAndZone::new(self, ())
369    }
370}
371
372impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
373
374/// An extention trait for [`ZonedAddr`].
375pub trait SocketZonedAddrExt<W, A, D> {
376    /// Returns the address and device that should be used for a socket.
377    ///
378    /// Given an address for a socket and an optional device that the socket is
379    /// already bound on, returns the address and device that should be used
380    /// for the socket. If `addr` and `device` require inconsistent devices,
381    /// or if `addr` requires a zone but there is none specified (by `addr` or
382    /// `device`), an error is returned.
383    fn resolve_addr_with_device(
384        self,
385        device: Option<D::Weak>,
386    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
387    where
388        D: StrongDeviceIdentifier;
389}
390
391impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
392where
393    W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
394    A: IpAddress,
395{
396    fn resolve_addr_with_device(
397        self,
398        device: Option<D::Weak>,
399    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
400    where
401        D: StrongDeviceIdentifier,
402    {
403        let (addr, zone) = self.into_addr_zone();
404        let device = match (zone, device) {
405            (Some(zone), Some(device)) => {
406                if device != zone {
407                    return Err(ZonedAddressError::DeviceZoneMismatch);
408                }
409                Some(EitherDeviceId::Strong(zone))
410            }
411            (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
412            (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
413            (None, None) => {
414                if addr.as_ref().must_have_zone() {
415                    return Err(ZonedAddressError::RequiredZoneNotProvided);
416                } else {
417                    None
418                }
419            }
420        };
421        Ok((addr, device))
422    }
423}
424
425/// A helper type to verify if applying socket updates is allowed for a given
426/// current state.
427///
428/// The fields in `SocketDeviceUpdate` define the current state,
429/// [`SocketDeviceUpdate::try_update`] applies the verification logic.
430pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
431    /// The current local IP address.
432    pub local_ip: Option<&'a SpecifiedAddr<A>>,
433    /// The current remote IP address.
434    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
435    /// The currently bound device.
436    pub old_device: Option<&'a D>,
437}
438
439impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
440    /// Checks if an update from `old_device` to `new_device` is allowed,
441    /// returning an error if not.
442    pub fn check_update<N>(
443        self,
444        new_device: Option<&N>,
445    ) -> Result<(), SocketDeviceUpdateNotAllowedError>
446    where
447        D: PartialEq<N>,
448    {
449        let Self { local_ip, remote_ip, old_device } = self;
450        let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
451            || remote_ip.is_some_and(|a| a.must_have_zone());
452
453        if !must_have_zone {
454            return Ok(());
455        }
456
457        let old_device = old_device.unwrap_or_else(|| {
458            panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
459        });
460
461        if new_device.is_some_and(|new_device| old_device == new_device) {
462            Ok(())
463        } else {
464            Err(SocketDeviceUpdateNotAllowedError)
465        }
466    }
467}
468
/// The device can't be updated on a socket.
///
/// Returned by [`SocketDeviceUpdate::check_update`].
// Derives match the sibling `IncompatibleError` so callers can compare,
// copy, and `Debug`-print (e.g. `unwrap`) this error.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SocketDeviceUpdateNotAllowedError;
471
/// Specification for the identifiers in an [`AddrVec`].
///
/// This is a convenience trait for bundling together the local and remote
/// identifiers for a protocol.
pub trait SocketMapAddrSpec {
    /// The local identifier portion of a socket address.
    ///
    /// Must be convertible into a [`NonZeroU16`] (presumably a port number;
    /// confirm per protocol).
    type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
    /// The remote identifier portion of a socket address.
    type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}
482
/// Information about the address in a [`ListenerAddr`].
///
/// Passed to [`SocketMapStateSpec::listener_tag`] when computing a listener's
/// tag.
pub struct ListenerAddrInfo {
    /// Whether the address has a device bound.
    pub has_device: bool,
    /// Whether the listener is on a specified address (as opposed to a blanket
    /// listener).
    pub specified_addr: bool,
}
491
492impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
493    pub(crate) fn info(&self) -> ListenerAddrInfo {
494        let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
495        ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
496    }
497}
498
499/// Specifies the types parameters for [`BoundSocketMap`] state as a single bundle.
500pub trait SocketMapStateSpec {
501    /// The tag value of a socket address vector entry.
502    ///
503    /// These values are derived from [`Self::ListenerAddrState`] and
504    /// [`Self::ConnAddrState`].
505    type AddrVecTag: Eq + Copy + Debug + 'static;
506
507    /// Returns a the tag for a listener in the socket map.
508    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;
509
510    /// Returns a the tag for a connected socket in the socket map.
511    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;
512
513    /// An identifier for a listening socket.
514    type ListenerId: Clone + Debug;
515    /// An identifier for a connected socket.
516    type ConnId: Clone + Debug;
517
518    /// The state stored for a listening socket that is used to determine
519    /// whether sockets can share an address.
520    type ListenerSharingState: Clone + Debug;
521
522    /// The state stored for a connected socket that is used to determine
523    /// whether sockets can share an address.
524    type ConnSharingState: Clone + Debug;
525
526    /// The state stored for a listener socket address.
527    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
528        + Debug;
529
530    /// The state stored for a connected socket address.
531    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
532        + Debug;
533}
534
/// Error returned by implementations of [`SocketMapAddrStateSpec`] to indicate
/// incompatible changes to a socket map.
///
/// Carries no payload; callers map it to a richer error (e.g. `InsertError`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
539
/// An inserter into a [`SocketMap`].
pub trait Inserter<T> {
    /// Inserts `item`, consuming the inserter.
    ///
    /// Because `self` is taken by value, a single inserter performs at most one
    /// insertion.
    fn insert(self, item: T);
}

// Any mutable reference to an `Extend` collection inserts by extending the
// collection with exactly one item.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        self.extend(core::iter::once(item))
    }
}

// `Never` is uninhabited, so this inserter can never be constructed or used.
impl<T> Inserter<T> for Never {
    fn insert(self, _item: T) {
        match self {}
    }
}
560
561/// Describes an entry in a [`SocketMap`] for a listener or connection address.
562pub trait SocketMapAddrStateSpec {
563    /// The type of ID that can be present at the address.
564    type Id;
565
566    /// The sharing state for the address.
567    ///
568    /// This can be used to determine whether a socket can be inserted at the
569    /// address. Every socket has its own sharing state associated with it,
570    /// though the sharing state is not necessarily stored in the address
571    /// entry.
572    type SharingState;
573
574    /// The type of inserter returned by [`SocketMapAddrStateSpec::try_get_inserter`].
575    type Inserter<'a>: Inserter<Self::Id> + 'a
576    where
577        Self: 'a,
578        Self::Id: 'a;
579
580    /// Creates a new `Self` holding the provided socket with the given new
581    /// sharing state at the specified address.
582    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;
583
584    /// Looks up the ID in self, returning `true` if it is present.
585    fn contains_id(&self, id: &Self::Id) -> bool;
586
587    /// Enables insertion in `self` for a new socket with the provided sharing
588    /// state.
589    ///
590    /// If the new state is incompatible with the existing socket(s),
591    /// implementations of this function should return `Err(IncompatibleError)`.
592    /// If `Ok(x)` is returned, calling `x.insert(y)` will insert `y` into
593    /// `self`.
594    fn try_get_inserter<'a, 'b>(
595        &'b mut self,
596        new_sharing_state: &'a Self::SharingState,
597    ) -> Result<Self::Inserter<'b>, IncompatibleError>;
598
599    /// Returns `Ok` if an entry with the given sharing state could be added
600    /// to `self`.
601    ///
602    /// If this returns `Ok`, `try_get_dest` should succeed.
603    fn could_insert(&self, new_sharing_state: &Self::SharingState)
604    -> Result<(), IncompatibleError>;
605
606    /// Removes the given socket from the existing state.
607    ///
608    /// Implementations should assume that `id` is contained in `self`.
609    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
610}
611
/// Provides behavior on updating the sharing state of a [`SocketMap`] entry.
pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
    /// Attempts to update the sharing state of the address state with id `id`
    /// to `new_sharing_state`.
    ///
    /// Returns `Err(IncompatibleError)` if the update is not allowed.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError>;
}
622
/// Provides conflict detection for a [`SocketMapStateSpec`].
///
/// `Addr` and `SharingState` select which kind of socket (listener or
/// connection) the policy applies to.
pub trait SocketMapConflictPolicy<
    Addr,
    SharingState,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
>: SocketMapStateSpec
{
    /// Checks whether a new socket with the provided state can be inserted at
    /// the given address in the existing socket map, returning an error
    /// otherwise.
    ///
    /// Implementations of this function should check for any potential
    /// conflicts that would arise when inserting a socket with state
    /// `new_sharing_state` into a new or existing entry at `addr` in
    /// `socketmap`.
    fn check_insert_conflicts(
        new_sharing_state: &SharingState,
        addr: &Addr,
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
    ) -> Result<(), InsertError>;
}
646
/// Defines the policy for updating the sharing state of entries in the
/// [`SocketMap`].
pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
    A: SocketMapAddrSpec,
{
    /// Returns `Ok(())` if the entry `addr` in `socketmap` allows the sharing
    /// state to transition from `old_sharing` to `new_sharing`, otherwise an
    /// [`UpdateSharingError`].
    fn allows_sharing_update(
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
        addr: &Addr,
        old_sharing: &SharingState,
        new_sharing: &SharingState,
    ) -> Result<(), UpdateSharingError>;
}
663
/// A bound socket state that is either a listener or a connection.
///
/// This is the value type of the socket map keyed by [`AddrVec`]; the variant
/// must match the key's variant.
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
#[allow(missing_docs)]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
    /// State for a listener address.
    Listen(S::ListenerAddrState),
    /// State for a connected address.
    Conn(S::ConnAddrState),
}
672
/// An "address vector" type that can hold any address in a [`SocketMap`].
///
/// This is a "vector" in the mathematical sense, in that it denotes an address
/// in a space. Here, the space is the possible addresses to which a socket
/// receiving IP packets can be bound.
///
/// `AddrVec`s are used as keys for the `SocketMap` type. Since an incoming
/// packet can match more than one address, for each incoming packet there is a
/// set of possible `AddrVec` keys whose entries (sockets) in a `SocketMap`
/// might receive the packet.
///
/// This set of keys can be ordered by precedence as described in the
/// documentation for [`AddrVecIter`]. Calling [`IterShadows::iter_shadows`] on
/// an instance will produce the sequence of addresses it has precedence over.
#[derive(Derivative)]
#[derivative(
    Debug(bound = "D: Debug"),
    Clone(bound = "D: Clone"),
    Eq(bound = "D: Eq"),
    PartialEq(bound = "D: PartialEq"),
    Hash(bound = "D: Hash")
)]
#[allow(missing_docs)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
    /// A listener address (optional local IP, local identifier, optional device).
    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    /// A connected address (local and remote IP/identifier, optional device).
    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
700
701impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
702    Tagged<AddrVec<I, D, A>> for Bound<S>
703{
704    type Tag = S::AddrVecTag;
705    fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
706        match (self, address) {
707            (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
708            (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
709                S::connected_tag(device.is_some(), c)
710            }
711            (Bound::Listen(_), AddrVec::Conn(_)) => {
712                unreachable!("found listen state for conn addr")
713            }
714            (Bound::Conn(_), AddrVec::Listen(_)) => {
715                unreachable!("found conn state for listen addr")
716            }
717        }
718    }
719}
720
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
    type IterShadows = AddrVecIter<I, D, A>;

    /// Returns an iterator over the addresses that `self` shadows (has
    /// precedence over), excluding `self` itself.
    ///
    /// # Panics
    ///
    /// Panics if the constructed [`AddrVecIter`] does not yield `self` first.
    fn iter_shadows(&self) -> Self::IterShadows {
        // Convert either address kind into the IP-address form and optional
        // device that `AddrVecIter` is constructed from.
        let (socket_ip_addr, device) = match self.clone() {
            AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
            AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
        };
        let mut iter = match device {
            Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
            None => AddrVecIter::without_device(socket_ip_addr),
        };
        // Skip the first element, which is always `*self`.
        assert_eq!(iter.next().as_ref(), Some(self));
        iter
    }
}
738
/// How a socket is bound on the system.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[allow(missing_docs)]
pub enum SocketAddrType {
    /// A listener without a specific local IP address.
    AnyListener,
    /// A listener on a specific local IP address.
    SpecificListener,
    /// A connected socket.
    Connected,
}
747
748impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
749    fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
750        match addr {
751            Some(_) => SocketAddrType::SpecificListener,
752            None => SocketAddrType::AnyListener,
753        }
754    }
755}
756
757impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
758    fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
759        SocketAddrType::Connected
760    }
761}
762
/// The result of attempting to remove a socket from a collection of sockets.
///
/// Returned by [`SocketMapAddrStateSpec::remove_by_id`].
pub enum RemoveResult {
    /// The value was removed successfully.
    Success,
    /// The value is the last value in the collection so the entire collection
    /// should be removed.
    IsLast,
}
771
/// The ID of a socket in a [`BoundSocketMap`], either a listener or a
/// connection.
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub enum SocketId<S: SocketMapStateSpec> {
    /// The ID of a listening socket.
    Listener(S::ListenerId),
    /// The ID of a connected socket.
    Connection(S::ConnId),
}
778
/// A map from socket addresses to sockets.
///
/// The types of keys and IDs are determined by the [`SocketMapStateSpec`]
/// parameter. Each listener and connected socket stores additional state.
/// Listener and connected sockets are keyed independently, but share the same
/// address vector space. Conflicts are detected on attempted insertion of new
/// sockets.
///
/// Listener addresses map to listener-address-specific state, and likewise
/// with connected addresses. Depending on protocol (determined by the
/// `SocketMapStateSpec` implementation), these address states can hold one or
/// more socket identifiers (e.g. UDP sockets with `SO_REUSEPORT` set can share
/// an address).
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    // Both listener and connected entries live in this single map, keyed by
    // `AddrVec`.
    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
797
798impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
799    BoundSocketMap<I, D, A, S>
800{
801    /// Returns the number of entries in the map.
802    pub fn len(&self) -> usize {
803        self.addr_to_state.len()
804    }
805}
806
/// Uninstantiable tag type for denoting listening sockets.
///
/// Used only at the type level (e.g. as the `SocketType` of [`Sockets`]).
pub enum Listener {}
/// Uninstantiable tag type for denoting connected sockets.
///
/// Used only at the type level (e.g. as the `SocketType` of [`Sockets`]).
pub enum Connection {}
811
/// View struct over one type of sockets in a [`BoundSocketMap`].
///
/// `AddrToStateMap` is a shared or exclusive reference to the underlying map;
/// `SocketType` selects listeners or connections and exists only as a marker.
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
814
815impl<
816    'a,
817    I: Ip,
818    D: DeviceIdentifier,
819    SocketType: ConvertSocketMapState<I, D, A, S>,
820    A: SocketMapAddrSpec,
821    S: SocketMapStateSpec,
822> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
823where
824    S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
825{
826    /// Returns the state at an address, if there is any.
827    pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
828        let Self(addr_to_state, _marker) = self;
829        addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
830            SocketType::from_bound_ref(state)
831                .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
832        })
833    }
834
835    /// Returns `Ok(())` if a socket could be inserted, otherwise an error.
836    ///
837    /// Goes through a dry run of inserting a socket at the given address and
838    /// with the given sharing state, returning `Ok(())` if the insertion would
839    /// succeed, otherwise the error that would be returned.
840    pub fn could_insert(
841        self,
842        addr: &SocketType::Addr,
843        sharing: &SocketType::SharingState,
844    ) -> Result<(), InsertError> {
845        let Self(addr_to_state, _) = self;
846        match self.get_by_addr(addr) {
847            Some(state) => {
848                state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
849            }
850            None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
851        }
852    }
853}
854
/// A borrowed state entry in a [`SocketMap`].
///
/// Refers to one socket (`id`) stored in the address state held by
/// `addr_entry`. `SocketType` is the [`Listener`] or [`Connection`] tag that
/// selects which kind of state the entry holds.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct SocketStateEntry<
    'a,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
    S: SocketMapStateSpec,
    SocketType,
> {
    /// The identifier of the socket this entry refers to.
    id: SocketId<S>,
    /// The occupied map entry whose address state contains `id`.
    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
    /// Ties the entry to its socket kind without storing a value.
    _marker: PhantomData<SocketType>,
}
870
871impl<
872    'a,
873    I: Ip,
874    D: DeviceIdentifier,
875    SocketType: ConvertSocketMapState<I, D, A, S>,
876    A: SocketMapAddrSpec,
877    S: SocketMapStateSpec
878        + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
879> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
880where
881    SocketType::SharingState: Clone,
882    SocketType::Id: Clone,
883{
884    /// Attempts to insert a new entry into the [`SocketMap`] backing this
885    /// `Sockets`.
886    pub fn try_insert(
887        self,
888        socket_addr: SocketType::Addr,
889        tag_state: SocketType::SharingState,
890        id: SocketType::Id,
891    ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, InsertError> {
892        self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
893            .map(|(entry, ())| entry)
894    }
895
896    /// Like [`Sockets::try_insert`] but calls `make_id` to create a socket ID
897    /// before inserting into the map.
898    ///
899    /// `make_id` returns type `R` that is returned to the caller on success.
900    pub fn try_insert_with<R>(
901        self,
902        socket_addr: SocketType::Addr,
903        tag_state: SocketType::SharingState,
904        make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
905    ) -> Result<(SocketStateEntry<'a, I, D, A, S, SocketType>, R), InsertError> {
906        let Self(addr_to_state, _) = self;
907        S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state)?;
908
909        let addr = SocketType::to_addr_vec(&socket_addr);
910
911        match addr_to_state.entry(addr) {
912            Entry::Occupied(mut o) => {
913                let (id, ret) = o.map_mut(|bound| {
914                    let bound = match SocketType::from_bound_mut(bound) {
915                        Some(bound) => bound,
916                        None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
917                    };
918                    match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
919                        bound, &tag_state,
920                    ) {
921                        Ok(v) => {
922                            let (id, ret) = make_id(socket_addr, tag_state);
923                            v.insert(id.clone());
924                            Ok((SocketType::to_socket_id(id), ret))
925                        }
926                        Err(IncompatibleError) => Err(InsertError::Exists),
927                    }
928                })?;
929                Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
930            }
931            Entry::Vacant(v) => {
932                let (id, ret) = make_id(socket_addr, tag_state.clone());
933                let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
934                    &tag_state,
935                    id.clone(),
936                )));
937                let id = SocketType::to_socket_id(id);
938                Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
939            }
940        }
941    }
942
943    /// Returns a borrowed entry at `id` and `addr`.
944    pub fn entry(
945        self,
946        id: &SocketType::Id,
947        addr: &SocketType::Addr,
948    ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
949        let Self(addr_to_state, _) = self;
950        let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
951            Entry::Vacant(_) => return None,
952            Entry::Occupied(o) => o,
953        };
954        let state = SocketType::from_bound_ref(addr_entry.get())?;
955
956        state.contains_id(id).then_some(SocketStateEntry {
957            id: SocketType::to_socket_id(id.clone()),
958            addr_entry,
959            _marker: PhantomData::default(),
960        })
961    }
962
963    /// Removes the entry with `id` and `addr`.
964    pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
965        self.entry(id, addr)
966            .map(|entry| {
967                entry.remove();
968            })
969            .ok_or(NotFoundError)
970    }
971}
972
/// The error returned when updating the sharing state for a [`SocketMap`] entry
/// fails.
///
/// Produced by [`SocketStateEntry::try_update_sharing`] when either the
/// socket map's sharing-update policy or the address state rejects the new
/// sharing state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UpdateSharingError;
977
978impl<
979    'a,
980    I: Ip,
981    D: DeviceIdentifier,
982    SocketType: ConvertSocketMapState<I, D, A, S>,
983    A: SocketMapAddrSpec,
984    S: SocketMapStateSpec,
985> SocketStateEntry<'a, I, D, A, S, SocketType>
986where
987    SocketType::Id: Clone,
988{
989    /// Returns this entry's address.
990    pub fn get_addr(&self) -> &SocketType::Addr {
991        let Self { id: _, addr_entry, _marker } = self;
992        SocketType::from_addr_vec_ref(addr_entry.key())
993    }
994
995    /// Returns this entry's identifier.
996    pub fn id(&self) -> &SocketType::Id {
997        let Self { id, addr_entry: _, _marker } = self;
998        SocketType::from_socket_id_ref(id)
999    }
1000
1001    /// Attempts to update the address for this entry.
1002    pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
1003        let Self { id, addr_entry, _marker } = self;
1004
1005        let new_addrvec = SocketType::to_addr_vec(&new_addr);
1006        let old_addr = addr_entry.key().clone();
1007        let (addr_state, addr_to_state) = addr_entry.remove_from_map();
1008        let addr_to_state = match addr_to_state.entry(new_addrvec) {
1009            Entry::Occupied(o) => o.into_map(),
1010            Entry::Vacant(v) => {
1011                if v.descendant_counts().len() != 0 {
1012                    v.into_map()
1013                } else {
1014                    let new_addr_entry = v.insert(addr_state);
1015                    return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
1016                }
1017            }
1018        };
1019        let to_restore = addr_state;
1020        // Restore the old state before returning an error.
1021        let addr_entry = match addr_to_state.entry(old_addr) {
1022            Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
1023            Entry::Vacant(v) => v.insert(to_restore),
1024        };
1025        return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
1026    }
1027
1028    /// Removes this entry from the map.
1029    pub fn remove(self) {
1030        let Self { id, mut addr_entry, _marker } = self;
1031        let addr = addr_entry.key().clone();
1032        match addr_entry.map_mut(|value| {
1033            let value = match SocketType::from_bound_mut(value) {
1034                Some(value) => value,
1035                None => unreachable!("found {:?} for address {:?}", value, addr),
1036            };
1037            value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
1038        }) {
1039            RemoveResult::Success => (),
1040            RemoveResult::IsLast => {
1041                let _: Bound<S> = addr_entry.remove();
1042            }
1043        }
1044    }
1045
1046    /// Attempts to update the sharing state for this entry.
1047    pub fn try_update_sharing(
1048        &mut self,
1049        old_sharing_state: &SocketType::SharingState,
1050        new_sharing_state: SocketType::SharingState,
1051    ) -> Result<(), UpdateSharingError>
1052    where
1053        SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1054        S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1055    {
1056        let Self { id, addr_entry, _marker } = self;
1057        let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1058
1059        S::allows_sharing_update(
1060            addr_entry.get_map(),
1061            addr,
1062            old_sharing_state,
1063            &new_sharing_state,
1064        )?;
1065
1066        addr_entry
1067            .map_mut(|value| {
1068                let value = match SocketType::from_bound_mut(value) {
1069                    Some(value) => value,
1070                    // We shouldn't ever be storing listener state in a bound
1071                    // address, or bound state in a listener address. Doing so means
1072                    // we've got a serious bug.
1073                    None => unreachable!("found invalid state {:?}", value),
1074                };
1075
1076                value.try_update_sharing(
1077                    SocketType::from_socket_id_ref(id).clone(),
1078                    &new_sharing_state,
1079                )
1080            })
1081            .map_err(|IncompatibleError| UpdateSharingError)
1082    }
1083}
1084
1085impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1086where
1087    AddrVec<I, D, A>: IterShadows,
1088    S: SocketMapStateSpec,
1089{
1090    /// Returns an iterator over the listeners on the socket map.
1091    pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1092    where
1093        S: SocketMapConflictPolicy<
1094                ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1095                <S as SocketMapStateSpec>::ListenerSharingState,
1096                I,
1097                D,
1098                A,
1099            >,
1100        S::ListenerAddrState:
1101            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1102    {
1103        let Self { addr_to_state } = self;
1104        Sockets(addr_to_state, Default::default())
1105    }
1106
1107    /// Returns a mutable iterator over the listeners on the socket map.
1108    pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1109    where
1110        S: SocketMapConflictPolicy<
1111                ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1112                <S as SocketMapStateSpec>::ListenerSharingState,
1113                I,
1114                D,
1115                A,
1116            >,
1117        S::ListenerAddrState:
1118            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1119    {
1120        let Self { addr_to_state } = self;
1121        Sockets(addr_to_state, Default::default())
1122    }
1123
1124    /// Returns an iterator over the connections on the socket map.
1125    pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1126    where
1127        S: SocketMapConflictPolicy<
1128                ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1129                <S as SocketMapStateSpec>::ConnSharingState,
1130                I,
1131                D,
1132                A,
1133            >,
1134        S::ConnAddrState:
1135            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1136    {
1137        let Self { addr_to_state } = self;
1138        Sockets(addr_to_state, Default::default())
1139    }
1140
1141    /// Returns a mutable iterator over the connections on the socket map.
1142    pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1143    where
1144        S: SocketMapConflictPolicy<
1145                ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1146                <S as SocketMapStateSpec>::ConnSharingState,
1147                I,
1148                D,
1149                A,
1150            >,
1151        S::ConnAddrState:
1152            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1153    {
1154        let Self { addr_to_state } = self;
1155        Sockets(addr_to_state, Default::default())
1156    }
1157
1158    #[cfg(test)]
1159    pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1160        let Self { addr_to_state } = self;
1161        addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1162    }
1163
1164    /// Gets the number of shadower entries for `addr`.
1165    pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1166        let Self { addr_to_state } = self;
1167        addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1168    }
1169}
1170
/// The type returned (wrapped in `Option`) by
/// [`BoundSocketMap::iter_receivers`].
///
/// `A` is the type of a found entry, and `It` is an iterator over such
/// entries.
pub enum FoundSockets<A, It> {
    /// A single recipient was found for the address.
    ///
    /// `iter_receivers` reports the first matching entry in lookup order.
    Single(A),
    /// Indicates the looked-up address was multicast, and holds an iterator of
    /// the found receivers.
    ///
    /// `iter_receivers` also uses this variant when a broadcast marker is
    /// given. The iterator may yield no receivers at all.
    Multicast(It),
}
1179
1180/// A borrowed entry in a [`BoundSocketMap`].
1181#[allow(missing_docs)]
1182#[derive(Debug)]
1183pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1184    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
1185    Conn(
1186        &'a S::ConnAddrState,
1187        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1188    ),
1189}
1190
1191impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1192where
1193    I: BroadcastIpExt<Addr: MulticastAddress>,
1194    D: DeviceIdentifier,
1195    A: SocketMapAddrSpec,
1196    S: SocketMapStateSpec
1197        + SocketMapConflictPolicy<
1198            ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1199            <S as SocketMapStateSpec>::ListenerSharingState,
1200            I,
1201            D,
1202            A,
1203        > + SocketMapConflictPolicy<
1204            ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1205            <S as SocketMapStateSpec>::ConnSharingState,
1206            I,
1207            D,
1208            A,
1209        >,
1210{
1211    /// Looks up a connected socket.
1212    ///
1213    /// This is a lightweight version of `iter_receivers()` that doesn't try to
1214    /// lookup listening sockets. It is used for early demux, which applies only
1215    /// to connected sockets.
1216    pub fn lookup_connected(
1217        &self,
1218        (src_ip, src_port): (SocketIpAddr<I::Addr>, A::RemoteIdentifier),
1219        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1220        device: D,
1221    ) -> Option<&'_ S::ConnAddrState> {
1222        let mut addr = ConnAddr {
1223            ip: ConnIpAddr { local: (dst_ip, dst_port), remote: (src_ip, src_port) },
1224            device: Some(device),
1225        };
1226        let entry = self.conns().get_by_addr(&addr);
1227        if entry.is_some() {
1228            return entry;
1229        }
1230        addr.device = None;
1231        self.conns().get_by_addr(&addr)
1232    }
1233
1234    /// Finds the socket(s) that should receive an incoming packet.
1235    ///
1236    /// Uses the provided addresses and receiving device to look up sockets that
1237    /// should receive a matching incoming packet. Returns `None` if no sockets
1238    /// were found, or the results of the lookup.
1239    pub fn iter_receivers(
1240        &self,
1241        (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1242        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1243        device: D,
1244        broadcast: Option<I::BroadcastMarker>,
1245    ) -> Option<
1246        FoundSockets<
1247            AddrEntry<'_, I, D, A, S>,
1248            impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1249        >,
1250    > {
1251        let mut matching_entries = AddrVecIter::with_device(
1252            match (src_ip, src_port) {
1253                (Some(specified_src_ip), Some(src_port)) => {
1254                    ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1255                        .into()
1256                }
1257                _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1258            },
1259            device,
1260        )
1261        .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1262            AddrVec::Listen(l) => {
1263                self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1264            }
1265            AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1266        });
1267
1268        if broadcast.is_some() || dst_ip.addr().is_multicast() {
1269            Some(FoundSockets::Multicast(matching_entries))
1270        } else {
1271            let single_entry: Option<_> = matching_entries.next();
1272            single_entry.map(FoundSockets::Single)
1273        }
1274    }
1275}
1276
/// Errors observed by [`SocketMapConflictPolicy`].
///
/// Also returned by [`Sockets::try_insert`] and [`Sockets::could_insert`].
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// A shadow address exists for the entry.
    ShadowAddrExists,
    /// Entry already exists.
    ///
    /// Also used when the address state already bound at the target address
    /// rejects the new socket's sharing state.
    Exists,
    /// Entry would shadow existing entry.
    WouldShadowExisting,
    /// An indirect conflict was detected.
    IndirectConflict,
}
1284    /// Entry would shadow existing entry.
1285    WouldShadowExisting,
1286    /// An indirect conflict was detected.
1287    IndirectConflict,
1288}
1289
1290impl From<InsertError> for LocalAddressError {
1291    fn from(value: InsertError) -> Self {
1292        match value {
1293            InsertError::ShadowAddrExists
1294            | InsertError::Exists
1295            | InsertError::IndirectConflict
1296            | InsertError::WouldShadowExisting => LocalAddressError::AddressInUse,
1297        }
1298    }
1299}
1300
1301/// Helper trait for converting between [`AddrVec`] and [`Bound`] and their
1302/// variants.
1303pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1304    type Id;
1305    type SharingState;
1306    type Addr: Debug;
1307    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;
1308
1309    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
1310    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
1311    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
1312    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
1313    fn to_bound(state: Self::AddrState) -> Bound<S>;
1314    fn to_socket_id(id: Self::Id) -> SocketId<S>;
1315    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
1316}
1317
1318impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1319    ConvertSocketMapState<I, D, A, S> for Listener
1320{
1321    type Id = S::ListenerId;
1322    type SharingState = S::ListenerSharingState;
1323    type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1324    type AddrState = S::ListenerAddrState;
1325    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1326        AddrVec::Listen(addr.clone())
1327    }
1328
1329    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1330        match addr {
1331            AddrVec::Listen(l) => l,
1332            AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1333        }
1334    }
1335
1336    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1337        match bound {
1338            Bound::Listen(l) => Some(l),
1339            Bound::Conn(_c) => None,
1340        }
1341    }
1342
1343    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1344        match bound {
1345            Bound::Listen(l) => Some(l),
1346            Bound::Conn(_c) => None,
1347        }
1348    }
1349
1350    fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1351        Bound::Listen(state)
1352    }
1353    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1354        match id {
1355            SocketId::Listener(id) => id,
1356            SocketId::Connection(_) => unreachable!("connection ID for listener"),
1357        }
1358    }
1359    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1360        SocketId::Listener(id)
1361    }
1362}
1363
1364impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1365    ConvertSocketMapState<I, D, A, S> for Connection
1366{
1367    type Id = S::ConnId;
1368    type SharingState = S::ConnSharingState;
1369    type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1370    type AddrState = S::ConnAddrState;
1371    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1372        AddrVec::Conn(addr.clone())
1373    }
1374
1375    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1376        match addr {
1377            AddrVec::Conn(c) => c,
1378            AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1379        }
1380    }
1381
1382    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1383        match bound {
1384            Bound::Listen(_l) => None,
1385            Bound::Conn(c) => Some(c),
1386        }
1387    }
1388
1389    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1390        match bound {
1391            Bound::Listen(_l) => None,
1392            Bound::Conn(c) => Some(c),
1393        }
1394    }
1395
1396    fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1397        Bound::Conn(state)
1398    }
1399
1400    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1401        match id {
1402            SocketId::Connection(id) => id,
1403            SocketId::Listener(_) => unreachable!("listener ID for connection"),
1404        }
1405    }
1406    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1407        SocketId::Connection(id)
1408    }
1409}
1410
1411/// An identifier of a sharing domain used for SO_REUSEPORT.
1412#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)]
1413pub struct SharingDomain(u64);
1414
1415impl SharingDomain {
1416    /// Creates a new instance with the specified ID. Caller must ensure that the `id`
1417    /// uniquely identifies the sharing domain and that the client is authorized to use it,
1418    /// e.g. on Fuchsia the ID is the KOID of a handle provided by the client.
1419    pub const fn new(id: u64) -> Self {
1420        SharingDomain(id)
1421    }
1422}
1423
1424/// A value of the SO_REUSEPORT option. Also encodes the sharing domain, which allows
1425/// to ensure that only sockets in the same domain can share ports.
1426#[derive(Default, Debug, Eq, PartialEq, Clone, Copy, Hash)]
1427pub enum ReusePortOption {
1428    /// The option is disabled.
1429    #[default]
1430    Disabled,
1431
1432    /// The option is enabled: the port is shareable with other sockets in the
1433    /// same sharing domain.
1434    Enabled(SharingDomain),
1435}
1436
1437impl ReusePortOption {
1438    /// Returns `true` if the option is enabled.
1439    pub fn is_enabled(&self) -> bool {
1440        matches!(self, ReusePortOption::Enabled(_))
1441    }
1442
1443    /// Returns `true` if the socket is shareable with a socket with the
1444    /// specified value of the SO_REUSEPORT option.
1445    pub fn is_shareable_with(&self, other: &Self) -> bool {
1446        match (self, other) {
1447            (ReusePortOption::Enabled(domain1), ReusePortOption::Enabled(domain2)) => {
1448                domain1 == domain2
1449            }
1450            _ => false,
1451        }
1452    }
1453}
1454
1455#[cfg(test)]
1456mod tests {
1457    use alloc::vec;
1458    use alloc::vec::Vec;
1459
1460    use assert_matches::assert_matches;
1461    use net_declare::{net_ip_v4, net_ip_v6};
1462    use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
1463    use netstack3_hashmap::HashSet;
1464    use test_case::test_case;
1465
1466    use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
1467    use crate::testutil::set_logger_for_test;
1468
1469    use super::*;
1470
1471    #[test_case(net_ip_v4!("8.8.8.8"))]
1472    #[test_case(net_ip_v4!("127.0.0.1"))]
1473    #[test_case(net_ip_v4!("127.0.8.9"))]
1474    #[test_case(net_ip_v4!("224.1.2.3"))]
1475    fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
1476        // No IPv4 addresses are allowed to have a zone.
1477        let addr = SpecifiedAddr::new(addr).unwrap();
1478        assert_eq!(addr.must_have_zone(), false);
1479    }
1480
1481    #[test_case(net_ip_v6!("1::2:3"), false)]
1482    #[test_case(net_ip_v6!("::1"), false; "localhost")]
1483    #[test_case(net_ip_v6!("1::"), false)]
1484    #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
1485    #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
1486    #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
1487    #[test_case(net_ip_v6!("fe80::1"), true)]
1488    fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
1489        // Only link-local unicast and multicast addresses are allowed to have
1490        // zones.
1491        let addr = SpecifiedAddr::new(addr).unwrap();
1492        assert_eq!(addr.must_have_zone(), must_have);
1493    }
1494
1495    #[test]
1496    fn try_into_null_zoned_ipv6() {
1497        assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
1498        let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
1499        const ZONE: u32 = 5;
1500        assert_eq!(
1501            zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
1502            Some(AddrAndZone::new(zoned, ZONE).unwrap())
1503        );
1504    }
1505
1506    enum FakeSpec {}
1507
1508    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1509    struct Listener(usize);
1510
1511    #[derive(PartialEq, Eq, Debug, Copy, Clone)]
1512    struct SharingState {
1513        tag: char,
1514        shared: bool,
1515    }
1516
1517    impl SharingState {
1518        fn exclusive(tag: char) -> Self {
1519            Self { tag, shared: false }
1520        }
1521
1522        fn shared(tag: char) -> Self {
1523            Self { tag, shared: true }
1524        }
1525    }
1526
1527    impl SharingState {
1528        fn can_share_with(&self, other: &Self) -> bool {
1529            self.tag == other.tag && self.shared && other.shared
1530        }
1531    }
1532
1533    #[derive(PartialEq, Eq, Debug)]
1534    struct Multiple<T> {
1535        sharing_state: SharingState,
1536        entries: Vec<T>,
1537    }
1538
1539    impl<T> Multiple<T> {
1540        fn new_exclusive(tag: char, entries: Vec<T>) -> Self {
1541            Self { sharing_state: SharingState { tag, shared: false }, entries }
1542        }
1543    }
1544
1545    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1546    struct Conn(usize);
1547
1548    enum FakeAddrSpec {}
1549
1550    impl SocketMapAddrSpec for FakeAddrSpec {
1551        type LocalIdentifier = NonZeroU16;
1552        type RemoteIdentifier = ();
1553    }
1554
1555    impl SocketMapStateSpec for FakeSpec {
1556        type AddrVecTag = SharingState;
1557
1558        type ListenerId = Listener;
1559        type ConnId = Conn;
1560
1561        type ListenerSharingState = SharingState;
1562        type ConnSharingState = SharingState;
1563
1564        type ListenerAddrState = Multiple<Listener>;
1565        type ConnAddrState = Multiple<Conn>;
1566
1567        fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
1568            state.sharing_state
1569        }
1570
1571        fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
1572            state.sharing_state
1573        }
1574    }
1575
1576    type FakeBoundSocketMap =
1577        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
1578
1579    /// Generator for unique socket IDs that don't have any state.
1580    ///
1581    /// Calling [`FakeSocketIdGen::next`] returns a unique ID.
1582    #[derive(Default)]
1583    struct FakeSocketIdGen {
1584        next_id: usize,
1585    }
1586
1587    impl FakeSocketIdGen {
1588        fn next(&mut self) -> usize {
1589            let next_next_id = self.next_id + 1;
1590            core::mem::replace(&mut self.next_id, next_next_id)
1591        }
1592    }
1593
1594    impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
1595        type Id = I;
1596        type SharingState = SharingState;
1597        type Inserter<'a>
1598            = &'a mut Vec<I>
1599        where
1600            I: 'a;
1601
1602        fn new(sharing_state: &SharingState, id: I) -> Self {
1603            Self { sharing_state: *sharing_state, entries: vec![id] }
1604        }
1605
1606        fn contains_id(&self, id: &Self::Id) -> bool {
1607            self.entries.contains(id)
1608        }
1609
1610        fn try_get_inserter<'a, 'b>(
1611            &'b mut self,
1612            new_sharing_state: &'a SharingState,
1613        ) -> Result<Self::Inserter<'b>, IncompatibleError> {
1614            (self.sharing_state == *new_sharing_state)
1615                .then_some(&mut self.entries)
1616                .ok_or(IncompatibleError)
1617        }
1618
1619        fn could_insert(&self, new_sharing_state: &SharingState) -> Result<(), IncompatibleError> {
1620            (self.sharing_state == *new_sharing_state).then_some(()).ok_or(IncompatibleError)
1621        }
1622
1623        fn remove_by_id(&mut self, id: I) -> RemoveResult {
1624            let index = self.entries.iter().position(|i| i == &id).expect("did not find id");
1625            let _: I = self.entries.swap_remove(index);
1626            if self.entries.is_empty() { RemoveResult::IsLast } else { RemoveResult::Success }
1627        }
1628    }
1629
1630    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1631        SocketMapConflictPolicy<A, SharingState, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1632        for FakeSpec
1633    {
1634        fn check_insert_conflicts(
1635            new_sharing_state: &SharingState,
1636            addr: &A,
1637            socketmap: &SocketMap<
1638                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1639                Bound<FakeSpec>,
1640            >,
1641        ) -> Result<(), InsertError> {
1642            let dest: AddrVec<_, _, _> = addr.clone().into();
1643            if dest.iter_shadows().any(|a| {
1644                let entry = socketmap.get(&a);
1645                match entry {
1646                    Some(Bound::Listen(Multiple { sharing_state, .. }))
1647                    | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1648                        !sharing_state.can_share_with(new_sharing_state)
1649                    }
1650                    None => false,
1651                }
1652            }) {
1653                return Err(InsertError::ShadowAddrExists);
1654            }
1655
1656            match socketmap.get(&dest) {
1657                Some(Bound::Listen(Multiple { sharing_state, .. }))
1658                | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1659                    // Require that all sockets inserted in a `Multiple` entry
1660                    // have the same sharing state.
1661                    if sharing_state != new_sharing_state {
1662                        return Err(InsertError::Exists);
1663                    }
1664                }
1665                None => (),
1666            }
1667
1668            if socketmap
1669                .descendant_counts(&dest)
1670                .any(|(sharing_state, _count)| !sharing_state.can_share_with(new_sharing_state))
1671            {
1672                Err(InsertError::WouldShadowExisting)
1673            } else {
1674                Ok(())
1675            }
1676        }
1677    }
1678
1679    impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
1680        fn try_update_sharing(
1681            &mut self,
1682            id: Self::Id,
1683            new_sharing_state: &Self::SharingState,
1684        ) -> Result<(), IncompatibleError> {
1685            if self.sharing_state == *new_sharing_state {
1686                return Ok(());
1687            }
1688
1689            // Preserve the invariant that all sockets inserted in a `Multiple`
1690            // entry have the same sharing state. That means we can't change
1691            // the sharing state of all the sockets at the address unless there
1692            // is exactly one!
1693            if self.entries.len() != 1 {
1694                return Err(IncompatibleError);
1695            }
1696            assert!(self.entries.contains(&id));
1697            self.sharing_state = *new_sharing_state;
1698            Ok(())
1699        }
1700    }
1701
1702    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1703        SocketMapUpdateSharingPolicy<
1704            A,
1705            SharingState,
1706            Ipv4,
1707            FakeWeakDeviceId<FakeDeviceId>,
1708            FakeAddrSpec,
1709        > for FakeSpec
1710    {
1711        fn allows_sharing_update(
1712            _socketmap: &SocketMap<
1713                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1714                Bound<Self>,
1715            >,
1716            _addr: &A,
1717            _old_sharing: &SharingState,
1718            _new_sharing_state: &SharingState,
1719        ) -> Result<(), UpdateSharingError> {
1720            Ok(())
1721        }
1722    }
1723
1724    const LISTENER_ADDR: ListenerAddr<
1725        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
1726        FakeWeakDeviceId<FakeDeviceId>,
1727    > = ListenerAddr {
1728        ip: ListenerIpAddr {
1729            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
1730            identifier: NonZeroU16::new(1).unwrap(),
1731        },
1732        device: None,
1733    };
1734
1735    const CONN_ADDR: ConnAddr<
1736        ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
1737        FakeWeakDeviceId<FakeDeviceId>,
1738    > = ConnAddr {
1739        ip: ConnIpAddr {
1740            local: (
1741                unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
1742                NonZeroU16::new(1).unwrap(),
1743            ),
1744            remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
1745        },
1746        device: None,
1747    };
1748
1749    #[test]
1750    fn bound_insert_get_remove_listener() {
1751        set_logger_for_test();
1752        let mut bound = FakeBoundSocketMap::default();
1753        let mut fake_id_gen = FakeSocketIdGen::default();
1754
1755        let addr = LISTENER_ADDR;
1756
1757        let id = {
1758            let entry = bound
1759                .listeners_mut()
1760                .try_insert(addr, SharingState::exclusive('v'), Listener(fake_id_gen.next()))
1761                .unwrap();
1762            assert_eq!(entry.get_addr(), &addr);
1763            entry.id().clone()
1764        };
1765
1766        assert_eq!(
1767            bound.listeners().get_by_addr(&addr),
1768            Some(&Multiple::new_exclusive('v', vec![id]))
1769        );
1770
1771        assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
1772        assert_eq!(bound.listeners().get_by_addr(&addr), None);
1773    }
1774
1775    #[test]
1776    fn bound_insert_get_remove_conn() {
1777        set_logger_for_test();
1778        let mut bound = FakeBoundSocketMap::default();
1779        let mut fake_id_gen = FakeSocketIdGen::default();
1780
1781        let addr = CONN_ADDR;
1782
1783        let id = {
1784            let entry = bound
1785                .conns_mut()
1786                .try_insert(addr, SharingState::exclusive('v'), Conn(fake_id_gen.next()))
1787                .unwrap();
1788            assert_eq!(entry.get_addr(), &addr);
1789            entry.id().clone()
1790        };
1791
1792        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('v', vec![id])));
1793
1794        assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
1795        assert_eq!(bound.conns().get_by_addr(&addr), None);
1796    }
1797
1798    #[test]
1799    fn bound_iter_addrs() {
1800        set_logger_for_test();
1801        let mut bound = FakeBoundSocketMap::default();
1802        let mut fake_id_gen = FakeSocketIdGen::default();
1803
1804        let listener_addrs = [
1805            (Some(net_ip_v4!("1.1.1.1")), 1),
1806            (Some(net_ip_v4!("2.2.2.2")), 2),
1807            (Some(net_ip_v4!("1.1.1.1")), 3),
1808            (None, 4),
1809        ]
1810        .map(|(ip, identifier)| ListenerAddr {
1811            device: None,
1812            ip: ListenerIpAddr {
1813                addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
1814                identifier: NonZeroU16::new(identifier).unwrap(),
1815            },
1816        });
1817        let conn_addrs = [
1818            (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
1819            (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
1820        ]
1821        .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
1822            ip: ConnIpAddr {
1823                local: (
1824                    SocketIpAddr::new(local_ip).unwrap(),
1825                    NonZeroU16::new(local_identifier).unwrap(),
1826                ),
1827                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
1828            },
1829            device: None,
1830        });
1831
1832        for addr in listener_addrs.iter().cloned() {
1833            let _entry = bound
1834                .listeners_mut()
1835                .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1836                .unwrap();
1837        }
1838        for addr in conn_addrs.iter().cloned() {
1839            let _entry = bound
1840                .conns_mut()
1841                .try_insert(addr, SharingState::exclusive('a'), Conn(fake_id_gen.next()))
1842                .unwrap();
1843        }
1844        let expected_addrs = listener_addrs
1845            .into_iter()
1846            .map(Into::into)
1847            .chain(conn_addrs.into_iter().map(Into::into))
1848            .collect::<HashSet<_>>();
1849
1850        assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
1851    }
1852
1853    #[test]
1854    fn try_insert_with_callback_not_called_on_error() {
1855        // TODO(https://fxbug.dev/42076891): remove this test along with
1856        // try_insert_with.
1857        set_logger_for_test();
1858        let mut bound = FakeBoundSocketMap::default();
1859        let addr = LISTENER_ADDR;
1860
1861        // Insert a listener so that future calls can conflict.
1862        let _: &Listener = bound
1863            .listeners_mut()
1864            .try_insert(addr, SharingState::exclusive('a'), Listener(0))
1865            .unwrap()
1866            .id();
1867
1868        // All of the below try_insert_with calls should fail, but more
1869        // importantly, they should not call the `make_id` callback (because it
1870        // is only called once success is certain).
1871        fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
1872            panic!("should never be called");
1873        }
1874
1875        assert_matches!(
1876            bound.listeners_mut().try_insert_with(
1877                addr,
1878                SharingState::exclusive('b'),
1879                is_never_called
1880            ),
1881            Err(InsertError::Exists)
1882        );
1883        assert_matches!(
1884            bound.listeners_mut().try_insert_with(
1885                ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
1886                SharingState::exclusive('b'),
1887                is_never_called
1888            ),
1889            Err(InsertError::ShadowAddrExists)
1890        );
1891        assert_matches!(
1892            bound.conns_mut().try_insert_with(
1893                ConnAddr {
1894                    device: None,
1895                    ip: ConnIpAddr {
1896                        local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1897                        remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1898                    },
1899                },
1900                SharingState::exclusive('b'),
1901                is_never_called,
1902            ),
1903            Err(InsertError::ShadowAddrExists)
1904        );
1905    }
1906
1907    #[test]
1908    fn insert_listener_conflict_with_listener() {
1909        set_logger_for_test();
1910        let mut bound = FakeBoundSocketMap::default();
1911        let mut fake_id_gen = FakeSocketIdGen::default();
1912        let addr = LISTENER_ADDR;
1913
1914        let _: &Listener = bound
1915            .listeners_mut()
1916            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1917            .unwrap()
1918            .id();
1919        assert_matches!(
1920            bound.listeners_mut().try_insert(
1921                addr,
1922                SharingState::exclusive('b'),
1923                Listener(fake_id_gen.next())
1924            ),
1925            Err(InsertError::Exists)
1926        );
1927    }
1928
1929    #[test]
1930    fn insert_listener_conflict_with_shadower() {
1931        set_logger_for_test();
1932        let mut bound = FakeBoundSocketMap::default();
1933        let mut fake_id_gen = FakeSocketIdGen::default();
1934        let addr = LISTENER_ADDR;
1935        let shadows_addr = {
1936            assert_eq!(addr.device, None);
1937            ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
1938        };
1939
1940        let _: &Listener = bound
1941            .listeners_mut()
1942            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1943            .unwrap()
1944            .id();
1945        assert_matches!(
1946            bound.listeners_mut().try_insert(
1947                shadows_addr,
1948                SharingState::exclusive('b'),
1949                Listener(fake_id_gen.next())
1950            ),
1951            Err(InsertError::ShadowAddrExists)
1952        );
1953    }
1954
1955    #[test]
1956    fn insert_conn_conflict_with_listener() {
1957        set_logger_for_test();
1958        let mut bound = FakeBoundSocketMap::default();
1959        let mut fake_id_gen = FakeSocketIdGen::default();
1960        let addr = LISTENER_ADDR;
1961        let shadows_addr = ConnAddr {
1962            device: None,
1963            ip: ConnIpAddr {
1964                local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1965                remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1966            },
1967        };
1968
1969        let _: &Listener = bound
1970            .listeners_mut()
1971            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1972            .unwrap()
1973            .id();
1974        assert_matches!(
1975            bound.conns_mut().try_insert(
1976                shadows_addr,
1977                SharingState::exclusive('b'),
1978                Conn(fake_id_gen.next())
1979            ),
1980            Err(InsertError::ShadowAddrExists)
1981        );
1982    }
1983
1984    #[test]
1985    fn insert_and_remove_listener() {
1986        set_logger_for_test();
1987        let mut bound = FakeBoundSocketMap::default();
1988        let mut fake_id_gen = FakeSocketIdGen::default();
1989        let addr = LISTENER_ADDR;
1990
1991        let a = bound
1992            .listeners_mut()
1993            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1994            .unwrap()
1995            .id()
1996            .clone();
1997        let b = bound
1998            .listeners_mut()
1999            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
2000            .unwrap()
2001            .id()
2002            .clone();
2003        assert_ne!(a, b);
2004
2005        assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
2006        assert_eq!(
2007            bound.listeners().get_by_addr(&addr),
2008            Some(&Multiple::new_exclusive('x', vec![b]))
2009        );
2010    }
2011
2012    #[test]
2013    fn insert_and_remove_conn() {
2014        set_logger_for_test();
2015        let mut bound = FakeBoundSocketMap::default();
2016        let mut fake_id_gen = FakeSocketIdGen::default();
2017        let addr = CONN_ADDR;
2018
2019        let a = bound
2020            .conns_mut()
2021            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
2022            .unwrap()
2023            .id()
2024            .clone();
2025        let b = bound
2026            .conns_mut()
2027            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
2028            .unwrap()
2029            .id()
2030            .clone();
2031        assert_ne!(a, b);
2032
2033        assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
2034        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('x', vec![b])));
2035    }
2036
2037    #[test]
2038    fn update_listener_to_shadowed_addr_fails() {
2039        let mut bound = FakeBoundSocketMap::default();
2040        let mut fake_id_gen = FakeSocketIdGen::default();
2041
2042        let first_addr = LISTENER_ADDR;
2043        let second_addr = ListenerAddr {
2044            ip: ListenerIpAddr {
2045                addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
2046                ..LISTENER_ADDR.ip
2047            },
2048            ..LISTENER_ADDR
2049        };
2050        let both_shadow = ListenerAddr {
2051            ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
2052            device: None,
2053        };
2054
2055        let first = bound
2056            .listeners_mut()
2057            .try_insert(first_addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
2058            .unwrap()
2059            .id()
2060            .clone();
2061        let second = bound
2062            .listeners_mut()
2063            .try_insert(second_addr, SharingState::exclusive('b'), Listener(fake_id_gen.next()))
2064            .unwrap()
2065            .id()
2066            .clone();
2067
2068        // Moving from (1, "aaa") to (1, None) should fail since it is shadowed
2069        // by (1, "yyy"), and vise versa.
2070        let (ExistsError, entry) = bound
2071            .listeners_mut()
2072            .entry(&second, &second_addr)
2073            .unwrap()
2074            .try_update_addr(both_shadow)
2075            .expect_err("update should fail");
2076
2077        // The entry should correspond to `second`.
2078        assert_eq!(entry.id(), &second);
2079        drop(entry);
2080
2081        let (ExistsError, entry) = bound
2082            .listeners_mut()
2083            .entry(&first, &first_addr)
2084            .unwrap()
2085            .try_update_addr(both_shadow)
2086            .expect_err("update should fail");
2087        assert_eq!(entry.get_addr(), &first_addr);
2088    }
2089
2090    #[test]
2091    fn nonexistent_conn_entry() {
2092        let mut map = FakeBoundSocketMap::default();
2093        let mut fake_id_gen = FakeSocketIdGen::default();
2094        let addr = CONN_ADDR;
2095        let conn_id = map
2096            .conns_mut()
2097            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2098            .expect("failed to insert")
2099            .id()
2100            .clone();
2101        assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
2102
2103        assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
2104    }
2105
2106    #[test]
2107    fn update_conn_sharing() {
2108        let mut map = FakeBoundSocketMap::default();
2109        let mut fake_id_gen = FakeSocketIdGen::default();
2110        let addr = CONN_ADDR;
2111        let mut entry = map
2112            .conns_mut()
2113            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2114            .expect("failed to insert");
2115
2116        entry
2117            .try_update_sharing(&SharingState::exclusive('a'), SharingState::exclusive('d'))
2118            .expect("worked");
2119        // Updating sharing is only allowed if there are no other occupants at
2120        // the address.
2121        let mut second_conn = map
2122            .conns_mut()
2123            .try_insert(addr.clone(), SharingState::exclusive('d'), Conn(fake_id_gen.next()))
2124            .expect("can insert");
2125        assert_matches!(
2126            second_conn
2127                .try_update_sharing(&SharingState::exclusive('d'), SharingState::exclusive('e')),
2128            Err(UpdateSharingError)
2129        );
2130    }
2131
2132    #[test]
2133    fn lookup_connected() {
2134        let mut map = FakeBoundSocketMap::default();
2135        let mut fake_id_gen = FakeSocketIdGen::default();
2136
2137        let sharing_state = SharingState::shared('a');
2138
2139        let device_id = FakeWeakDeviceId(FakeDeviceId);
2140        let entry1 = map
2141            .conns_mut()
2142            .try_insert(CONN_ADDR, sharing_state, Conn(fake_id_gen.next()))
2143            .expect("failed to insert")
2144            .id()
2145            .clone();
2146        let conn = map
2147            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2148            .expect("lookup should succeed");
2149        assert!(conn.contains_id(&entry1));
2150
2151        // Add a second entry with a device ID. This one should be preferred
2152        // over the first one.
2153        let addr_with_device = ConnAddr { device: Some(device_id), ..CONN_ADDR };
2154        let entry2 = map
2155            .conns_mut()
2156            .try_insert(addr_with_device, sharing_state, Conn(fake_id_gen.next()))
2157            .expect("failed to insert")
2158            .id()
2159            .clone();
2160        let conn = map
2161            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2162            .expect("lookup should succeed");
2163        assert!(conn.contains_id(&entry2));
2164    }
2165}