netstack3_base/socket/
base.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! General-purpose socket utilities common to device layer and IP layer
6//! sockets.
7
8use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17    AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19use thiserror::Error;
20
21use crate::data_structures::socketmap::{
22    Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
23};
24use crate::device::{
25    DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
26};
27use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
28use crate::ip::BroadcastIpExt;
29use crate::socket::address::{
30    AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
31};
32
33/// A dual stack IP extention trait that provides the `OtherVersion` associated
34/// type.
35pub trait DualStackIpExt: Ip {
36    /// The "other" IP version, e.g. [`Ipv4`] for [`Ipv6`] and vice-versa.
37    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
38}
39
40impl DualStackIpExt for Ipv4 {
41    type OtherVersion = Ipv6;
42}
43
44impl DualStackIpExt for Ipv6 {
45    type OtherVersion = Ipv4;
46}
47
48/// A tuple of values for `T` for both `I` and `I::OtherVersion`.
49pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
50    this_stack: <T as GenericOverIp<I>>::Type,
51    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
52    _marker: IpVersionMarker<I>,
53}
54
55impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
56    /// Creates a new tuple with `this_stack` and `other_stack` values.
57    pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
58    where
59        T: GenericOverIp<I, Type = T>,
60    {
61        Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
62    }
63
64    /// Retrieves `(this_stack, other_stack)` from the tuple.
65    pub fn into_inner(
66        self,
67    ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
68        let Self { this_stack, other_stack, _marker } = self;
69        (this_stack, other_stack)
70    }
71
72    /// Retrieves `this_stack` from the tuple.
73    pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
74        self.this_stack
75    }
76
77    /// Borrows `this_stack` from the tuple.
78    pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
79        &self.this_stack
80    }
81
82    /// Retrieves `other_stack` from the tuple.
83    pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
84        self.other_stack
85    }
86
87    /// Borrows `other_stack` from the tuple.
88    pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
89        &self.other_stack
90    }
91
92    /// Flips the types, making `this_stack` `other_stack` and vice-versa.
93    pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
94        let Self { this_stack, other_stack, _marker } = self;
95        DualStackTuple {
96            this_stack: other_stack,
97            other_stack: this_stack,
98            _marker: IpVersionMarker::new(),
99        }
100    }
101
102    /// Casts to IP version `X`.
103    ///
104    /// Given `DualStackTuple` contains complete information for both IP
105    /// versions, it can be easily cast into an arbitrary `X` IP version.
106    ///
107    /// This can be used to tie together type parameters when dealing with dual
108    /// stack sockets. For example, a `DualStackTuple` defined for `SockI` can
109    /// be cast to any `WireI`.
110    pub fn cast<X>(self) -> DualStackTuple<X, T>
111    where
112        X: DualStackIpExt,
113        T: GenericOverIp<X>
114            + GenericOverIp<X::OtherVersion>
115            + GenericOverIp<Ipv4>
116            + GenericOverIp<Ipv6>,
117    {
118        I::map_ip_in(
119            self,
120            |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
121            |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
122        )
123    }
124}
125
126impl<
127    I: DualStackIpExt,
128    NewIp: DualStackIpExt,
129    T: GenericOverIp<NewIp>
130        + GenericOverIp<NewIp::OtherVersion>
131        + GenericOverIp<I>
132        + GenericOverIp<I::OtherVersion>,
133> GenericOverIp<NewIp> for DualStackTuple<I, T>
134{
135    type Type = DualStackTuple<NewIp, T>;
136}
137
/// Extension trait for `Ip` providing socket-specific functionality.
///
/// Every `I: Ip` implements this trait through the blanket impl that follows
/// the trait definition.
pub trait SocketIpExt: Ip {
    /// `Self::LOOPBACK_ADDRESS`, but wrapped in the `SocketIpAddr` type.
    ///
    /// Built with the unchecked constructor so that it can be evaluated in a
    /// const context.
    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
        // SAFETY: The loopback address is a valid SocketIpAddr, as verified
        // in the `loopback_addr_is_valid_socket_addr` test.
        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
    };
}

// Blanket implementation: every IP version gets the socket extensions.
impl<I: Ip> SocketIpExt for I {}
149
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    // Guards the soundness of the `unsafe` block that builds
    // `SocketIpExt::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR`.
    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        // `LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR` is defined with the "unchecked"
        // constructor (which supports const construction). Verify here that the
        // address actually satisfies all the requirements (protecting against
        // far-away changes).
        let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
            .expect("loopback address should be a valid SocketIpAddr");
    }
}
165
/// State belonging to either IP stack.
///
/// Like `either::Either`, but with more helpful variant names.
///
/// Note that this type is not optimally type-safe: it has no IP version
/// parameter, so `T` and `O` are not tied to some `I` and `I::OtherVersion`,
/// respectively. In many cases it may be more appropriate to define a one-off
/// enum parameterized over `I: Ip`.
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// In the current stack version.
    ThisStack(T),
    /// In the other version of the stack.
    OtherStack(O),
}

// Hand-written rather than derived so that `track_caller` can be attached when
// the `instrumented` feature is enabled; a derived impl cannot carry that
// attribute.
impl<T, O> Clone for EitherStack<T, O>
where
    T: Clone,
    O: Clone,
{
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        match self {
            Self::ThisStack(this) => Self::ThisStack(this.clone()),
            Self::OtherStack(other) => Self::OtherStack(other.clone()),
        }
    }
}
194
/// Control flow type containing either a dual-stack or non-dual-stack context.
///
/// This type exists to provide nice names to the result of
/// [`BoundStateContext::dual_stack_context`], and to allow generic code to
/// match on when checking whether a socket protocol and IP version support
/// dual-stack operation. If dual-stack operation is supported, a
/// [`MaybeDualStack::DualStack`] value will be held, otherwise a
/// [`MaybeDualStack::NotDualStack`] value.
///
/// Note that the templated types do not have trait bounds; those are provided
/// by the trait with the `dual_stack_context` function.
///
/// In monomorphized code, this type frequently has exactly one template
/// parameter that is uninstantiable (it contains an instance of
/// [`core::convert::Infallible`] or some other empty enum, or a reference to
/// the same)! That lets the compiler optimize it out completely, creating no
/// actual runtime overhead.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum MaybeDualStack<DS, NDS> {
    /// Dual-stack operation is supported; holds the dual-stack value.
    DualStack(DS),
    /// Dual-stack operation is not supported; holds the non-dual-stack value.
    NotDualStack(NDS),
}
218
219// Implement `GenericOverIp` for a `MaybeDualStack` whose `DS` and `NDS` also
220// implement `GenericOverIp`.
221impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
222    for MaybeDualStack<DS, NDS>
223{
224    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
225}
226
227/// An error encountered while enabling or disabling dual-stack operation.
228#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
229#[generic_over_ip()]
230pub enum SetDualStackEnabledError {
231    /// A socket can only have dual stack enabled or disabled while unbound.
232    #[error("a socket can only have dual stack enabled or disabled while unbound")]
233    SocketIsBound,
234    /// The socket's protocol is not dual stack capable.
235    #[error(transparent)]
236    NotCapable(#[from] NotDualStackCapableError),
237}
238
239/// An error encountered when attempting to perform dual stack operations on
240/// socket with a non dual stack capable protocol.
241#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
242#[generic_over_ip()]
243#[error("socket's protocol is not dual-stack capable")]
244pub struct NotDualStackCapableError;
245
/// The shutdown status of a socket's send and receive data paths.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// Whether the owning socket's send path has been shut down.
    ///
    /// When `true`, the socket should not be able to send packets.
    pub send: bool,
    /// Whether the owning socket's receive path has been shut down.
    ///
    /// When `true`, the socket should not be able to receive packets.
    pub receive: bool,
}
258
259/// Which direction(s) to shut down for a socket.
260#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
261#[generic_over_ip()]
262pub enum ShutdownType {
263    /// Prevent sending packets on the socket.
264    Send,
265    /// Prevent receiving packets on the socket.
266    Receive,
267    /// Prevent sending and receiving packets on the socket.
268    SendAndReceive,
269}
270
271impl ShutdownType {
272    /// Returns a tuple of booleans for `(shutdown_send, shutdown_receive)`.
273    pub fn to_send_receive(&self) -> (bool, bool) {
274        match self {
275            Self::Send => (true, false),
276            Self::Receive => (false, true),
277            Self::SendAndReceive => (true, true),
278        }
279    }
280
281    /// Creates a [`ShutdownType`] from a pair of bools for send and receive.
282    pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
283        match (send, receive) {
284            (true, false) => Some(Self::Send),
285            (false, true) => Some(Self::Receive),
286            (true, true) => Some(Self::SendAndReceive),
287            (false, false) => None,
288        }
289    }
290}
291
292/// Extensions to IP Address witnesses useful in the context of sockets.
293pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
294    /// Determines whether the provided address is underspecified by itself.
295    ///
296    /// Some addresses are ambiguous and so must have a zone identifier in order
297    /// to be used in a socket address. This function returns true for IPv6
298    /// link-local addresses and false for all others.
299    fn must_have_zone(&self) -> bool
300    where
301        Self: Copy,
302    {
303        self.try_into_null_zoned().is_some()
304    }
305
306    /// Converts into a [`AddrAndZone<A, ()>`] if the address requires a zone.
307    ///
308    /// Otherwise returns `None`.
309    fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
310        if self.get().is_loopback() {
311            return None;
312        }
313        AddrAndZone::new(self, ())
314    }
315}
316
317impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
318
319/// An extention trait for [`ZonedAddr`].
320pub trait SocketZonedAddrExt<W, A, D> {
321    /// Returns the address and device that should be used for a socket.
322    ///
323    /// Given an address for a socket and an optional device that the socket is
324    /// already bound on, returns the address and device that should be used
325    /// for the socket. If `addr` and `device` require inconsistent devices,
326    /// or if `addr` requires a zone but there is none specified (by `addr` or
327    /// `device`), an error is returned.
328    fn resolve_addr_with_device(
329        self,
330        device: Option<D::Weak>,
331    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
332    where
333        D: StrongDeviceIdentifier;
334}
335
336impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
337where
338    W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
339    A: IpAddress,
340{
341    fn resolve_addr_with_device(
342        self,
343        device: Option<D::Weak>,
344    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
345    where
346        D: StrongDeviceIdentifier,
347    {
348        let (addr, zone) = self.into_addr_zone();
349        let device = match (zone, device) {
350            (Some(zone), Some(device)) => {
351                if device != zone {
352                    return Err(ZonedAddressError::DeviceZoneMismatch);
353                }
354                Some(EitherDeviceId::Strong(zone))
355            }
356            (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
357            (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
358            (None, None) => {
359                if addr.as_ref().must_have_zone() {
360                    return Err(ZonedAddressError::RequiredZoneNotProvided);
361                } else {
362                    None
363                }
364            }
365        };
366        Ok((addr, device))
367    }
368}
369
370/// A helper type to verify if applying socket updates is allowed for a given
371/// current state.
372///
373/// The fields in `SocketDeviceUpdate` define the current state,
374/// [`SocketDeviceUpdate::try_update`] applies the verification logic.
375pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
376    /// The current local IP address.
377    pub local_ip: Option<&'a SpecifiedAddr<A>>,
378    /// The current remote IP address.
379    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
380    /// The currently bound device.
381    pub old_device: Option<&'a D>,
382}
383
384impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
385    /// Checks if an update from `old_device` to `new_device` is allowed,
386    /// returning an error if not.
387    pub fn check_update<N>(
388        self,
389        new_device: Option<&N>,
390    ) -> Result<(), SocketDeviceUpdateNotAllowedError>
391    where
392        D: PartialEq<N>,
393    {
394        let Self { local_ip, remote_ip, old_device } = self;
395        let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
396            || remote_ip.is_some_and(|a| a.must_have_zone());
397
398        if !must_have_zone {
399            return Ok(());
400        }
401
402        let old_device = old_device.unwrap_or_else(|| {
403            panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
404        });
405
406        if new_device.is_some_and(|new_device| old_device == new_device) {
407            Ok(())
408        } else {
409            Err(SocketDeviceUpdateNotAllowedError)
410        }
411    }
412}
413
/// The device can't be updated on a socket.
///
/// Returned by [`SocketDeviceUpdate::check_update`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SocketDeviceUpdateNotAllowedError;
416
/// Specification for the identifiers in an [`AddrVec`].
///
/// Bundles a protocol's local and remote identifier types into a single type
/// parameter.
pub trait SocketMapAddrSpec {
    /// The identifier for the local side of a socket address; it must be
    /// convertible into a [`NonZeroU16`].
    type LocalIdentifier: Into<NonZeroU16> + Copy + Clone + Eq + Hash + Debug + Send + Sync;
    /// The identifier for the remote side of a socket address.
    type RemoteIdentifier: Copy + Clone + Eq + Hash + Debug + Send + Sync;
}
427
/// A summary of the properties of a [`ListenerAddr`].
pub struct ListenerAddrInfo {
    /// `true` when the listener is bound to a device.
    pub has_device: bool,
    /// `true` when the listener is bound to a specific local address, `false`
    /// for a blanket (any-address) listener.
    pub specified_addr: bool,
}
436
437impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
438    pub(crate) fn info(&self) -> ListenerAddrInfo {
439        let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
440        ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
441    }
442}
443
444/// Specifies the types parameters for [`BoundSocketMap`] state as a single bundle.
445pub trait SocketMapStateSpec {
446    /// The tag value of a socket address vector entry.
447    ///
448    /// These values are derived from [`Self::ListenerAddrState`] and
449    /// [`Self::ConnAddrState`].
450    type AddrVecTag: Eq + Copy + Debug + 'static;
451
452    /// Returns a the tag for a listener in the socket map.
453    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;
454
455    /// Returns a the tag for a connected socket in the socket map.
456    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;
457
458    /// An identifier for a listening socket.
459    type ListenerId: Clone + Debug;
460    /// An identifier for a connected socket.
461    type ConnId: Clone + Debug;
462
463    /// The state stored for a listening socket that is used to determine
464    /// whether sockets can share an address.
465    type ListenerSharingState: Clone + Debug;
466
467    /// The state stored for a connected socket that is used to determine
468    /// whether sockets can share an address.
469    type ConnSharingState: Clone + Debug;
470
471    /// The state stored for a listener socket address.
472    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
473        + Debug;
474
475    /// The state stored for a connected socket address.
476    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
477        + Debug;
478}
479
/// Signals, from a [`SocketMapAddrStateSpec`] implementation, that a requested
/// change to a socket map is incompatible with its current contents.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
484
/// Inserts a single item into a [`SocketMap`] entry.
pub trait Inserter<T> {
    /// Inserts `item`, consuming `self`.
    ///
    /// Because the inserter is taken by value, at most one item can be
    /// inserted through it.
    fn insert(self, item: T);
}

// A mutable reference to any extendable collection is an inserter that appends
// the item to that collection.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        self.extend(Some(item))
    }
}

// `Never` has no values, so this method can never actually run. Presumably it
// exists so `Never` can serve as the inserter type where insertion is
// impossible.
impl<T> Inserter<T> for Never {
    fn insert(self, _: T) {
        match self {}
    }
}
505
506/// Describes an entry in a [`SocketMap`] for a listener or connection address.
507pub trait SocketMapAddrStateSpec {
508    /// The type of ID that can be present at the address.
509    type Id;
510
511    /// The sharing state for the address.
512    ///
513    /// This can be used to determine whether a socket can be inserted at the
514    /// address. Every socket has its own sharing state associated with it,
515    /// though the sharing state is not necessarily stored in the address
516    /// entry.
517    type SharingState;
518
519    /// The type of inserter returned by [`SocketMapAddrStateSpec::try_get_inserter`].
520    type Inserter<'a>: Inserter<Self::Id> + 'a
521    where
522        Self: 'a,
523        Self::Id: 'a;
524
525    /// Creates a new `Self` holding the provided socket with the given new
526    /// sharing state at the specified address.
527    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;
528
529    /// Looks up the ID in self, returning `true` if it is present.
530    fn contains_id(&self, id: &Self::Id) -> bool;
531
532    /// Enables insertion in `self` for a new socket with the provided sharing
533    /// state.
534    ///
535    /// If the new state is incompatible with the existing socket(s),
536    /// implementations of this function should return `Err(IncompatibleError)`.
537    /// If `Ok(x)` is returned, calling `x.insert(y)` will insert `y` into
538    /// `self`.
539    fn try_get_inserter<'a, 'b>(
540        &'b mut self,
541        new_sharing_state: &'a Self::SharingState,
542    ) -> Result<Self::Inserter<'b>, IncompatibleError>;
543
544    /// Returns `Ok` if an entry with the given sharing state could be added
545    /// to `self`.
546    ///
547    /// If this returns `Ok`, `try_get_dest` should succeed.
548    fn could_insert(&self, new_sharing_state: &Self::SharingState)
549    -> Result<(), IncompatibleError>;
550
551    /// Removes the given socket from the existing state.
552    ///
553    /// Implementations should assume that `id` is contained in `self`.
554    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
555}
556
557/// Provides behavior on updating the sharing state of a [`SocketMap`] entry.
558pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
559    /// Attempts to update the sharing state of the address state with id `id`
560    /// to `new_sharing_state`.
561    fn try_update_sharing(
562        &mut self,
563        id: Self::Id,
564        new_sharing_state: &Self::SharingState,
565    ) -> Result<(), IncompatibleError>;
566}
567
568/// Provides conflict detection for a [`SocketMapStateSpec`].
569pub trait SocketMapConflictPolicy<
570    Addr,
571    SharingState,
572    I: Ip,
573    D: DeviceIdentifier,
574    A: SocketMapAddrSpec,
575>: SocketMapStateSpec
576{
577    /// Checks whether a new socket with the provided state can be inserted at
578    /// the given address in the existing socket map, returning an error
579    /// otherwise.
580    ///
581    /// Implementations of this function should check for any potential
582    /// conflicts that would arise when inserting a socket with state
583    /// `new_sharing_state` into a new or existing entry at `addr` in
584    /// `socketmap`.
585    fn check_insert_conflicts(
586        new_sharing_state: &SharingState,
587        addr: &Addr,
588        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
589    ) -> Result<(), InsertError>;
590}
591
592/// Defines the policy for updating the sharing state of entries in the
593/// [`SocketMap`].
594pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
595    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
596where
597    A: SocketMapAddrSpec,
598{
599    /// Returns whether the entry `addr` in `socketmap` allows the sharing state
600    /// to transition from `old_sharing` to `new_sharing`.
601    fn allows_sharing_update(
602        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
603        addr: &Addr,
604        old_sharing: &SharingState,
605        new_sharing: &SharingState,
606    ) -> Result<(), UpdateSharingError>;
607}
608
609/// A bound socket state that is either a listener or a connection.
610#[derive(Derivative)]
611#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
612#[allow(missing_docs)]
613pub enum Bound<S: SocketMapStateSpec + ?Sized> {
614    Listen(S::ListenerAddrState),
615    Conn(S::ConnAddrState),
616}
617
618/// An "address vector" type that can hold any address in a [`SocketMap`].
619///
620/// This is a "vector" in the mathematical sense, in that it denotes an address
621/// in a space. Here, the space is the possible addresses to which a socket
622/// receiving IP packets can be bound.
623///
624/// `AddrVec`s are used as keys for the `SocketMap` type. Since an incoming
625/// packet can match more than one address, for each incoming packet there is a
626/// set of possible `AddrVec` keys whose entries (sockets) in a `SocketMap`
627/// might receive the packet.
628///
629/// This set of keys can be ordered by precedence as described in the
630/// documentation for [`AddrVecIter`]. Calling [`IterShadows::iter_shadows`] on
631/// an instance will produce the sequence of addresses it has precedence over.
632#[derive(Derivative)]
633#[derivative(
634    Debug(bound = "D: Debug"),
635    Clone(bound = "D: Clone"),
636    Eq(bound = "D: Eq"),
637    PartialEq(bound = "D: PartialEq"),
638    Hash(bound = "D: Hash")
639)]
640#[allow(missing_docs)]
641pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
642    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
643    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
644}
645
646impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
647    Tagged<AddrVec<I, D, A>> for Bound<S>
648{
649    type Tag = S::AddrVecTag;
650    fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
651        match (self, address) {
652            (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
653            (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
654                S::connected_tag(device.is_some(), c)
655            }
656            (Bound::Listen(_), AddrVec::Conn(_)) => {
657                unreachable!("found listen state for conn addr")
658            }
659            (Bound::Conn(_), AddrVec::Listen(_)) => {
660                unreachable!("found conn state for listen addr")
661            }
662        }
663    }
664}
665
666impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
667    type IterShadows = AddrVecIter<I, D, A>;
668
669    fn iter_shadows(&self) -> Self::IterShadows {
670        let (socket_ip_addr, device) = match self.clone() {
671            AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
672            AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
673        };
674        let mut iter = match device {
675            Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
676            None => AddrVecIter::without_device(socket_ip_addr),
677        };
678        // Skip the first element, which is always `*self`.
679        assert_eq!(iter.next().as_ref(), Some(self));
680        iter
681    }
682}
683
/// How a socket is bound on the system.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum SocketAddrType {
    /// A listener without a specific local IP address.
    AnyListener,
    /// A listener bound to a specific local IP address.
    SpecificListener,
    /// A connected socket.
    Connected,
}
692
693impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
694    fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
695        match addr {
696            Some(_) => SocketAddrType::SpecificListener,
697            None => SocketAddrType::AnyListener,
698        }
699    }
700}
701
702impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
703    fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
704        SocketAddrType::Connected
705    }
706}
707
/// The result of attempting to remove a socket from a collection of sockets.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum RemoveResult {
    /// The value was removed successfully.
    Success,
    /// The value is the last value in the collection so the entire collection
    /// should be removed.
    IsLast,
}
716
717#[derive(Derivative)]
718#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
719pub enum SocketId<S: SocketMapStateSpec> {
720    Listener(S::ListenerId),
721    Connection(S::ConnId),
722}
723
724/// A map from socket addresses to sockets.
725///
726/// The types of keys and IDs is determined by the [`SocketMapStateSpec`]
727/// parameter. Each listener and connected socket stores additional state.
728/// Listener and connected sockets are keyed independently, but share the same
729/// address vector space. Conflicts are detected on attempted insertion of new
730/// sockets.
731///
732/// Listener addresses map to listener-address-specific state, and likewise
733/// with connected addresses. Depending on protocol (determined by the
734/// `SocketMapStateSpec` protocol), these address states can hold one or more
735/// socket identifiers (e.g. UDP sockets with `SO_REUSEPORT` set can share an
736/// address).
737#[derive(Derivative)]
738#[derivative(Default(bound = ""))]
739pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
740    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
741}
742
743impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
744    BoundSocketMap<I, D, A, S>
745{
746    /// Returns the number of entries in the map.
747    pub fn len(&self) -> usize {
748        self.addr_to_state.len()
749    }
750}
751
/// Type-level marker selecting listening sockets; it has no values.
pub enum Listener {}
/// Type-level marker selecting connected sockets; it has no values.
pub enum Connection {}
756
/// A view over one kind of socket (selected by `SocketType`) in a
/// [`BoundSocketMap`].
///
/// The first field is the underlying address-to-state map (or a reference to
/// it). The second field only carries the `SocketType` parameter.
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
759
760impl<
761    'a,
762    I: Ip,
763    D: DeviceIdentifier,
764    SocketType: ConvertSocketMapState<I, D, A, S>,
765    A: SocketMapAddrSpec,
766    S: SocketMapStateSpec,
767> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
768where
769    S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
770{
771    /// Returns the state at an address, if there is any.
772    pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
773        let Self(addr_to_state, _marker) = self;
774        addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
775            SocketType::from_bound_ref(state)
776                .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
777        })
778    }
779
780    /// Returns `Ok(())` if a socket could be inserted, otherwise an error.
781    ///
782    /// Goes through a dry run of inserting a socket at the given address and
783    /// with the given sharing state, returning `Ok(())` if the insertion would
784    /// succeed, otherwise the error that would be returned.
785    pub fn could_insert(
786        self,
787        addr: &SocketType::Addr,
788        sharing: &SocketType::SharingState,
789    ) -> Result<(), InsertError> {
790        let Self(addr_to_state, _) = self;
791        match self.get_by_addr(addr) {
792            Some(state) => {
793                state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
794            }
795            None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
796        }
797    }
798}
799
800/// A borrowed state entry in a [`SocketMap`].
801#[derive(Derivative)]
802#[derivative(Debug(bound = ""))]
803pub struct SocketStateEntry<
804    'a,
805    I: Ip,
806    D: DeviceIdentifier,
807    A: SocketMapAddrSpec,
808    S: SocketMapStateSpec,
809    SocketType,
810> {
811    id: SocketId<S>,
812    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
813    _marker: PhantomData<SocketType>,
814}
815
816impl<
817    'a,
818    I: Ip,
819    D: DeviceIdentifier,
820    SocketType: ConvertSocketMapState<I, D, A, S>,
821    A: SocketMapAddrSpec,
822    S: SocketMapStateSpec
823        + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
824> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
825where
826    SocketType::SharingState: Clone,
827    SocketType::Id: Clone,
828{
829    /// Attempts to insert a new entry into the [`SocketMap`] backing this
830    /// `Sockets`.
831    pub fn try_insert(
832        self,
833        socket_addr: SocketType::Addr,
834        tag_state: SocketType::SharingState,
835        id: SocketType::Id,
836    ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, (InsertError, SocketType::SharingState)>
837    {
838        self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
839            .map(|(entry, ())| entry)
840    }
841
842    /// Like [`Sockets::try_insert`] but calls `make_id` to create a socket ID
843    /// before inserting into the map.
844    ///
845    /// `make_id` returns type `R` that is returned to the caller on success.
846    pub fn try_insert_with<R>(
847        self,
848        socket_addr: SocketType::Addr,
849        tag_state: SocketType::SharingState,
850        make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
851    ) -> Result<
852        (SocketStateEntry<'a, I, D, A, S, SocketType>, R),
853        (InsertError, SocketType::SharingState),
854    > {
855        let Self(addr_to_state, _) = self;
856        match S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state) {
857            Err(e) => return Err((e, tag_state)),
858            Ok(()) => (),
859        };
860
861        let addr = SocketType::to_addr_vec(&socket_addr);
862
863        match addr_to_state.entry(addr) {
864            Entry::Occupied(mut o) => {
865                let (id, ret) = o.map_mut(|bound| {
866                    let bound = match SocketType::from_bound_mut(bound) {
867                        Some(bound) => bound,
868                        None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
869                    };
870                    match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
871                        bound, &tag_state,
872                    ) {
873                        Ok(v) => {
874                            let (id, ret) = make_id(socket_addr, tag_state);
875                            v.insert(id.clone());
876                            Ok((SocketType::to_socket_id(id), ret))
877                        }
878                        Err(IncompatibleError) => Err((InsertError::Exists, tag_state)),
879                    }
880                })?;
881                Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
882            }
883            Entry::Vacant(v) => {
884                let (id, ret) = make_id(socket_addr, tag_state.clone());
885                let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
886                    &tag_state,
887                    id.clone(),
888                )));
889                let id = SocketType::to_socket_id(id);
890                Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
891            }
892        }
893    }
894
895    /// Returns a borrowed entry at `id` and `addr`.
896    pub fn entry(
897        self,
898        id: &SocketType::Id,
899        addr: &SocketType::Addr,
900    ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
901        let Self(addr_to_state, _) = self;
902        let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
903            Entry::Vacant(_) => return None,
904            Entry::Occupied(o) => o,
905        };
906        let state = SocketType::from_bound_ref(addr_entry.get())?;
907
908        state.contains_id(id).then_some(SocketStateEntry {
909            id: SocketType::to_socket_id(id.clone()),
910            addr_entry,
911            _marker: PhantomData::default(),
912        })
913    }
914
915    /// Removes the entry with `id` and `addr`.
916    pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
917        self.entry(id, addr)
918            .map(|entry| {
919                entry.remove();
920            })
921            .ok_or(NotFoundError)
922    }
923}
924
/// The error returned when updating the sharing state for a [`SocketMap`] entry
/// fails.
///
/// Returned by [`SocketStateEntry::try_update_sharing`] in two cases: the
/// sharing policy rejects the update, or the state at the entry's address
/// reports the new sharing state as incompatible.
#[derive(Debug)]
pub struct UpdateSharingError;
929
930impl<
931    'a,
932    I: Ip,
933    D: DeviceIdentifier,
934    SocketType: ConvertSocketMapState<I, D, A, S>,
935    A: SocketMapAddrSpec,
936    S: SocketMapStateSpec,
937> SocketStateEntry<'a, I, D, A, S, SocketType>
938where
939    SocketType::Id: Clone,
940{
941    /// Returns this entry's address.
942    pub fn get_addr(&self) -> &SocketType::Addr {
943        let Self { id: _, addr_entry, _marker } = self;
944        SocketType::from_addr_vec_ref(addr_entry.key())
945    }
946
947    /// Returns this entry's identifier.
948    pub fn id(&self) -> &SocketType::Id {
949        let Self { id, addr_entry: _, _marker } = self;
950        SocketType::from_socket_id_ref(id)
951    }
952
953    /// Attempts to update the address for this entry.
954    pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
955        let Self { id, addr_entry, _marker } = self;
956
957        let new_addrvec = SocketType::to_addr_vec(&new_addr);
958        let old_addr = addr_entry.key().clone();
959        let (addr_state, addr_to_state) = addr_entry.remove_from_map();
960        let addr_to_state = match addr_to_state.entry(new_addrvec) {
961            Entry::Occupied(o) => o.into_map(),
962            Entry::Vacant(v) => {
963                if v.descendant_counts().len() != 0 {
964                    v.into_map()
965                } else {
966                    let new_addr_entry = v.insert(addr_state);
967                    return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
968                }
969            }
970        };
971        let to_restore = addr_state;
972        // Restore the old state before returning an error.
973        let addr_entry = match addr_to_state.entry(old_addr) {
974            Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
975            Entry::Vacant(v) => v.insert(to_restore),
976        };
977        return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
978    }
979
980    /// Removes this entry from the map.
981    pub fn remove(self) {
982        let Self { id, mut addr_entry, _marker } = self;
983        let addr = addr_entry.key().clone();
984        match addr_entry.map_mut(|value| {
985            let value = match SocketType::from_bound_mut(value) {
986                Some(value) => value,
987                None => unreachable!("found {:?} for address {:?}", value, addr),
988            };
989            value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
990        }) {
991            RemoveResult::Success => (),
992            RemoveResult::IsLast => {
993                let _: Bound<S> = addr_entry.remove();
994            }
995        }
996    }
997
998    /// Attempts to update the sharing state for this entry.
999    pub fn try_update_sharing(
1000        &mut self,
1001        old_sharing_state: &SocketType::SharingState,
1002        new_sharing_state: SocketType::SharingState,
1003    ) -> Result<(), UpdateSharingError>
1004    where
1005        SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1006        S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1007    {
1008        let Self { id, addr_entry, _marker } = self;
1009        let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1010
1011        S::allows_sharing_update(
1012            addr_entry.get_map(),
1013            addr,
1014            old_sharing_state,
1015            &new_sharing_state,
1016        )?;
1017
1018        addr_entry
1019            .map_mut(|value| {
1020                let value = match SocketType::from_bound_mut(value) {
1021                    Some(value) => value,
1022                    // We shouldn't ever be storing listener state in a bound
1023                    // address, or bound state in a listener address. Doing so means
1024                    // we've got a serious bug.
1025                    None => unreachable!("found invalid state {:?}", value),
1026                };
1027
1028                value.try_update_sharing(
1029                    SocketType::from_socket_id_ref(id).clone(),
1030                    &new_sharing_state,
1031                )
1032            })
1033            .map_err(|IncompatibleError| UpdateSharingError)
1034    }
1035}
1036
1037impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1038where
1039    AddrVec<I, D, A>: IterShadows,
1040    S: SocketMapStateSpec,
1041{
1042    /// Returns an iterator over the listeners on the socket map.
1043    pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1044    where
1045        S: SocketMapConflictPolicy<
1046                ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1047                <S as SocketMapStateSpec>::ListenerSharingState,
1048                I,
1049                D,
1050                A,
1051            >,
1052        S::ListenerAddrState:
1053            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1054    {
1055        let Self { addr_to_state } = self;
1056        Sockets(addr_to_state, Default::default())
1057    }
1058
1059    /// Returns a mutable iterator over the listeners on the socket map.
1060    pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1061    where
1062        S: SocketMapConflictPolicy<
1063                ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1064                <S as SocketMapStateSpec>::ListenerSharingState,
1065                I,
1066                D,
1067                A,
1068            >,
1069        S::ListenerAddrState:
1070            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1071    {
1072        let Self { addr_to_state } = self;
1073        Sockets(addr_to_state, Default::default())
1074    }
1075
1076    /// Returns an iterator over the connections on the socket map.
1077    pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1078    where
1079        S: SocketMapConflictPolicy<
1080                ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1081                <S as SocketMapStateSpec>::ConnSharingState,
1082                I,
1083                D,
1084                A,
1085            >,
1086        S::ConnAddrState:
1087            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1088    {
1089        let Self { addr_to_state } = self;
1090        Sockets(addr_to_state, Default::default())
1091    }
1092
1093    /// Returns a mutable iterator over the connections on the socket map.
1094    pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1095    where
1096        S: SocketMapConflictPolicy<
1097                ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1098                <S as SocketMapStateSpec>::ConnSharingState,
1099                I,
1100                D,
1101                A,
1102            >,
1103        S::ConnAddrState:
1104            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1105    {
1106        let Self { addr_to_state } = self;
1107        Sockets(addr_to_state, Default::default())
1108    }
1109
1110    #[cfg(test)]
1111    pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1112        let Self { addr_to_state } = self;
1113        addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1114    }
1115
1116    /// Gets the number of shadower entries for `addr`.
1117    pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1118        let Self { addr_to_state } = self;
1119        addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1120    }
1121}
1122
/// The type returned by [`BoundSocketMap::iter_receivers`].
///
/// `A` is the type of a single matching entry, and `It` is an iterator over
/// matching entries.
pub enum FoundSockets<A, It> {
    /// A single recipient was found for the address.
    Single(A),
    /// Indicates the looked-up address was multicast, or the lookup was for
    /// broadcast. Holds a lazy iterator of the found receivers, which may be
    /// empty.
    Multicast(It),
}
1131
1132/// A borrowed entry in a [`BoundSocketMap`].
1133#[allow(missing_docs)]
1134#[derive(Debug)]
1135pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1136    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
1137    Conn(
1138        &'a S::ConnAddrState,
1139        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1140    ),
1141}
1142
1143impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1144where
1145    I: BroadcastIpExt<Addr: MulticastAddress>,
1146    D: DeviceIdentifier,
1147    A: SocketMapAddrSpec,
1148    S: SocketMapStateSpec
1149        + SocketMapConflictPolicy<
1150            ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1151            <S as SocketMapStateSpec>::ListenerSharingState,
1152            I,
1153            D,
1154            A,
1155        > + SocketMapConflictPolicy<
1156            ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1157            <S as SocketMapStateSpec>::ConnSharingState,
1158            I,
1159            D,
1160            A,
1161        >,
1162{
1163    /// Looks up a connected socket.
1164    ///
1165    /// This is a lightweight version of `iter_receivers()` that doesn't try to
1166    /// lookup listening sockets. It is used for early demux, which applies only
1167    /// to connected sockets.
1168    pub fn lookup_connected(
1169        &self,
1170        (src_ip, src_port): (SocketIpAddr<I::Addr>, A::RemoteIdentifier),
1171        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1172        device: D,
1173    ) -> Option<&'_ S::ConnAddrState> {
1174        let mut addr = ConnAddr {
1175            ip: ConnIpAddr { local: (dst_ip, dst_port), remote: (src_ip, src_port) },
1176            device: Some(device),
1177        };
1178        let entry = self.conns().get_by_addr(&addr);
1179        if entry.is_some() {
1180            return entry;
1181        }
1182        addr.device = None;
1183        self.conns().get_by_addr(&addr)
1184    }
1185
1186    /// Finds the socket(s) that should receive an incoming packet.
1187    ///
1188    /// Uses the provided addresses and receiving device to look up sockets that
1189    /// should receive a matching incoming packet. Returns `None` if no sockets
1190    /// were found, or the results of the lookup.
1191    pub fn iter_receivers(
1192        &self,
1193        (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1194        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1195        device: D,
1196        broadcast: Option<I::BroadcastMarker>,
1197    ) -> Option<
1198        FoundSockets<
1199            AddrEntry<'_, I, D, A, S>,
1200            impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1201        >,
1202    > {
1203        let mut matching_entries = AddrVecIter::with_device(
1204            match (src_ip, src_port) {
1205                (Some(specified_src_ip), Some(src_port)) => {
1206                    ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1207                        .into()
1208                }
1209                _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1210            },
1211            device,
1212        )
1213        .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1214            AddrVec::Listen(l) => {
1215                self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1216            }
1217            AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1218        });
1219
1220        if broadcast.is_some() || dst_ip.addr().is_multicast() {
1221            Some(FoundSockets::Multicast(matching_entries))
1222        } else {
1223            let single_entry: Option<_> = matching_entries.next();
1224            single_entry.map(FoundSockets::Single)
1225        }
1226    }
1227}
1228
/// Errors observed by [`SocketMapConflictPolicy`].
///
/// These are also the errors reported by the [`Sockets`] insertion methods
/// (`try_insert`, `try_insert_with`, and the dry-run `could_insert`).
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// A conflicting entry exists at an address that the entry's address
    /// shadows.
    ShadowAddrExists,
    /// An incompatible entry already exists at the same address.
    Exists,
    /// A conflicting entry exists at an address that shadows the entry's
    /// address.
    ShadowerExists,
    /// An indirect conflict was detected.
    IndirectConflict,
}
1241
1242/// Helper trait for converting between [`AddrVec`] and [`Bound`] and their
1243/// variants.
1244pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1245    type Id;
1246    type SharingState;
1247    type Addr: Debug;
1248    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;
1249
1250    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
1251    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
1252    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
1253    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
1254    fn to_bound(state: Self::AddrState) -> Bound<S>;
1255    fn to_socket_id(id: Self::Id) -> SocketId<S>;
1256    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
1257}
1258
1259impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1260    ConvertSocketMapState<I, D, A, S> for Listener
1261{
1262    type Id = S::ListenerId;
1263    type SharingState = S::ListenerSharingState;
1264    type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1265    type AddrState = S::ListenerAddrState;
1266    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1267        AddrVec::Listen(addr.clone())
1268    }
1269
1270    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1271        match addr {
1272            AddrVec::Listen(l) => l,
1273            AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1274        }
1275    }
1276
1277    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1278        match bound {
1279            Bound::Listen(l) => Some(l),
1280            Bound::Conn(_c) => None,
1281        }
1282    }
1283
1284    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1285        match bound {
1286            Bound::Listen(l) => Some(l),
1287            Bound::Conn(_c) => None,
1288        }
1289    }
1290
1291    fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1292        Bound::Listen(state)
1293    }
1294    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1295        match id {
1296            SocketId::Listener(id) => id,
1297            SocketId::Connection(_) => unreachable!("connection ID for listener"),
1298        }
1299    }
1300    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1301        SocketId::Listener(id)
1302    }
1303}
1304
1305impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1306    ConvertSocketMapState<I, D, A, S> for Connection
1307{
1308    type Id = S::ConnId;
1309    type SharingState = S::ConnSharingState;
1310    type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1311    type AddrState = S::ConnAddrState;
1312    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1313        AddrVec::Conn(addr.clone())
1314    }
1315
1316    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1317        match addr {
1318            AddrVec::Conn(c) => c,
1319            AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1320        }
1321    }
1322
1323    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1324        match bound {
1325            Bound::Listen(_l) => None,
1326            Bound::Conn(c) => Some(c),
1327        }
1328    }
1329
1330    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1331        match bound {
1332            Bound::Listen(_l) => None,
1333            Bound::Conn(c) => Some(c),
1334        }
1335    }
1336
1337    fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1338        Bound::Conn(state)
1339    }
1340
1341    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1342        match id {
1343            SocketId::Connection(id) => id,
1344            SocketId::Listener(_) => unreachable!("listener ID for connection"),
1345        }
1346    }
1347    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1348        SocketId::Connection(id)
1349    }
1350}
1351
/// An identifier of a sharing domain used for SO_REUSEPORT.
///
/// Sockets can only share a port through SO_REUSEPORT when they carry the
/// same domain.
#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)]
pub struct SharingDomain(u64);

impl SharingDomain {
    /// Creates a sharing domain from the raw identifier `id`.
    ///
    /// The caller must ensure that `id` uniquely identifies the sharing
    /// domain and that the client is authorized to use it. For example, on
    /// Fuchsia the ID is the KOID of a handle provided by the client.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }
}
1364
1365/// A value of the SO_REUSEPORT option. Also encodes the sharing domain, which allows
1366/// to ensure that only sockets in the same domain can share ports.
1367#[derive(Default, Debug, Eq, PartialEq, Clone, Copy, Hash)]
1368pub enum ReusePortOption {
1369    /// The option is disabled.
1370    #[default]
1371    Disabled,
1372
1373    /// The option is enabled: the port is shareable with other sockets in the
1374    /// same sharing domain.
1375    Enabled(SharingDomain),
1376}
1377
1378impl ReusePortOption {
1379    /// Returns `true` if the option is enabled.
1380    pub fn is_enabled(&self) -> bool {
1381        matches!(self, ReusePortOption::Enabled(_))
1382    }
1383
1384    /// Returns `true` if the socket is shareable with a socket with the
1385    /// specified value of the SO_REUSEPORT option.
1386    pub fn is_shareable_with(&self, other: &Self) -> bool {
1387        match (self, other) {
1388            (ReusePortOption::Enabled(domain1), ReusePortOption::Enabled(domain2)) => {
1389                domain1 == domain2
1390            }
1391            _ => false,
1392        }
1393    }
1394}
1395
1396#[cfg(test)]
1397mod tests {
1398    use alloc::vec;
1399    use alloc::vec::Vec;
1400
1401    use assert_matches::assert_matches;
1402    use net_declare::{net_ip_v4, net_ip_v6};
1403    use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
1404    use netstack3_hashmap::HashSet;
1405    use test_case::test_case;
1406
1407    use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
1408    use crate::testutil::set_logger_for_test;
1409
1410    use super::*;
1411
1412    #[test_case(net_ip_v4!("8.8.8.8"))]
1413    #[test_case(net_ip_v4!("127.0.0.1"))]
1414    #[test_case(net_ip_v4!("127.0.8.9"))]
1415    #[test_case(net_ip_v4!("224.1.2.3"))]
1416    fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
1417        // No IPv4 addresses are allowed to have a zone.
1418        let addr = SpecifiedAddr::new(addr).unwrap();
1419        assert_eq!(addr.must_have_zone(), false);
1420    }
1421
1422    #[test_case(net_ip_v6!("1::2:3"), false)]
1423    #[test_case(net_ip_v6!("::1"), false; "localhost")]
1424    #[test_case(net_ip_v6!("1::"), false)]
1425    #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
1426    #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
1427    #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
1428    #[test_case(net_ip_v6!("fe80::1"), true)]
1429    fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
1430        // Only link-local unicast and multicast addresses are allowed to have
1431        // zones.
1432        let addr = SpecifiedAddr::new(addr).unwrap();
1433        assert_eq!(addr.must_have_zone(), must_have);
1434    }
1435
1436    #[test]
1437    fn try_into_null_zoned_ipv6() {
1438        assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
1439        let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
1440        const ZONE: u32 = 5;
1441        assert_eq!(
1442            zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
1443            Some(AddrAndZone::new(zoned, ZONE).unwrap())
1444        );
1445    }
1446
1447    enum FakeSpec {}
1448
1449    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1450    struct Listener(usize);
1451
1452    #[derive(PartialEq, Eq, Debug, Copy, Clone)]
1453    struct SharingState {
1454        tag: char,
1455        shared: bool,
1456    }
1457
1458    impl SharingState {
1459        fn exclusive(tag: char) -> Self {
1460            Self { tag, shared: false }
1461        }
1462
1463        fn shared(tag: char) -> Self {
1464            Self { tag, shared: true }
1465        }
1466    }
1467
1468    impl SharingState {
1469        fn can_share_with(&self, other: &Self) -> bool {
1470            self.tag == other.tag && self.shared && other.shared
1471        }
1472    }
1473
1474    #[derive(PartialEq, Eq, Debug)]
1475    struct Multiple<T> {
1476        sharing_state: SharingState,
1477        entries: Vec<T>,
1478    }
1479
1480    impl<T> Multiple<T> {
1481        fn new_exclusive(tag: char, entries: Vec<T>) -> Self {
1482            Self { sharing_state: SharingState { tag, shared: false }, entries }
1483        }
1484    }
1485
1486    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1487    struct Conn(usize);
1488
1489    enum FakeAddrSpec {}
1490
1491    impl SocketMapAddrSpec for FakeAddrSpec {
1492        type LocalIdentifier = NonZeroU16;
1493        type RemoteIdentifier = ();
1494    }
1495
1496    impl SocketMapStateSpec for FakeSpec {
1497        type AddrVecTag = SharingState;
1498
1499        type ListenerId = Listener;
1500        type ConnId = Conn;
1501
1502        type ListenerSharingState = SharingState;
1503        type ConnSharingState = SharingState;
1504
1505        type ListenerAddrState = Multiple<Listener>;
1506        type ConnAddrState = Multiple<Conn>;
1507
1508        fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
1509            state.sharing_state
1510        }
1511
1512        fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
1513            state.sharing_state
1514        }
1515    }
1516
1517    type FakeBoundSocketMap =
1518        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
1519
1520    /// Generator for unique socket IDs that don't have any state.
1521    ///
1522    /// Calling [`FakeSocketIdGen::next`] returns a unique ID.
1523    #[derive(Default)]
1524    struct FakeSocketIdGen {
1525        next_id: usize,
1526    }
1527
1528    impl FakeSocketIdGen {
1529        fn next(&mut self) -> usize {
1530            let next_next_id = self.next_id + 1;
1531            core::mem::replace(&mut self.next_id, next_next_id)
1532        }
1533    }
1534
1535    impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
1536        type Id = I;
1537        type SharingState = SharingState;
1538        type Inserter<'a>
1539            = &'a mut Vec<I>
1540        where
1541            I: 'a;
1542
1543        fn new(sharing_state: &SharingState, id: I) -> Self {
1544            Self { sharing_state: *sharing_state, entries: vec![id] }
1545        }
1546
1547        fn contains_id(&self, id: &Self::Id) -> bool {
1548            self.entries.contains(id)
1549        }
1550
1551        fn try_get_inserter<'a, 'b>(
1552            &'b mut self,
1553            new_sharing_state: &'a SharingState,
1554        ) -> Result<Self::Inserter<'b>, IncompatibleError> {
1555            (self.sharing_state == *new_sharing_state)
1556                .then_some(&mut self.entries)
1557                .ok_or(IncompatibleError)
1558        }
1559
1560        fn could_insert(&self, new_sharing_state: &SharingState) -> Result<(), IncompatibleError> {
1561            (self.sharing_state == *new_sharing_state).then_some(()).ok_or(IncompatibleError)
1562        }
1563
1564        fn remove_by_id(&mut self, id: I) -> RemoveResult {
1565            let index = self.entries.iter().position(|i| i == &id).expect("did not find id");
1566            let _: I = self.entries.swap_remove(index);
1567            if self.entries.is_empty() { RemoveResult::IsLast } else { RemoveResult::Success }
1568        }
1569    }
1570
1571    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1572        SocketMapConflictPolicy<A, SharingState, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1573        for FakeSpec
1574    {
1575        fn check_insert_conflicts(
1576            new_sharing_state: &SharingState,
1577            addr: &A,
1578            socketmap: &SocketMap<
1579                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1580                Bound<FakeSpec>,
1581            >,
1582        ) -> Result<(), InsertError> {
1583            let dest: AddrVec<_, _, _> = addr.clone().into();
1584            if dest.iter_shadows().any(|a| {
1585                let entry = socketmap.get(&a);
1586                match entry {
1587                    Some(Bound::Listen(Multiple { sharing_state, .. }))
1588                    | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1589                        !sharing_state.can_share_with(new_sharing_state)
1590                    }
1591                    None => false,
1592                }
1593            }) {
1594                return Err(InsertError::ShadowAddrExists);
1595            }
1596
1597            match socketmap.get(&dest) {
1598                Some(Bound::Listen(Multiple { sharing_state, .. }))
1599                | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1600                    // Require that all sockets inserted in a `Multiple` entry
1601                    // have the same sharing state.
1602                    if sharing_state != new_sharing_state {
1603                        return Err(InsertError::Exists);
1604                    }
1605                }
1606                None => (),
1607            }
1608
1609            if socketmap
1610                .descendant_counts(&dest)
1611                .any(|(sharing_state, _count)| !sharing_state.can_share_with(new_sharing_state))
1612            {
1613                Err(InsertError::ShadowerExists)
1614            } else {
1615                Ok(())
1616            }
1617        }
1618    }
1619
1620    impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
1621        fn try_update_sharing(
1622            &mut self,
1623            id: Self::Id,
1624            new_sharing_state: &Self::SharingState,
1625        ) -> Result<(), IncompatibleError> {
1626            if self.sharing_state == *new_sharing_state {
1627                return Ok(());
1628            }
1629
1630            // Preserve the invariant that all sockets inserted in a `Multiple`
1631            // entry have the same sharing state. That means we can't change
1632            // the sharing state of all the sockets at the address unless there
1633            // is exactly one!
1634            if self.entries.len() != 1 {
1635                return Err(IncompatibleError);
1636            }
1637            assert!(self.entries.contains(&id));
1638            self.sharing_state = *new_sharing_state;
1639            Ok(())
1640        }
1641    }
1642
1643    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1644        SocketMapUpdateSharingPolicy<
1645            A,
1646            SharingState,
1647            Ipv4,
1648            FakeWeakDeviceId<FakeDeviceId>,
1649            FakeAddrSpec,
1650        > for FakeSpec
1651    {
1652        fn allows_sharing_update(
1653            _socketmap: &SocketMap<
1654                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1655                Bound<Self>,
1656            >,
1657            _addr: &A,
1658            _old_sharing: &SharingState,
1659            _new_sharing_state: &SharingState,
1660        ) -> Result<(), UpdateSharingError> {
1661            Ok(())
1662        }
1663    }
1664
1665    const LISTENER_ADDR: ListenerAddr<
1666        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
1667        FakeWeakDeviceId<FakeDeviceId>,
1668    > = ListenerAddr {
1669        ip: ListenerIpAddr {
1670            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
1671            identifier: NonZeroU16::new(1).unwrap(),
1672        },
1673        device: None,
1674    };
1675
1676    const CONN_ADDR: ConnAddr<
1677        ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
1678        FakeWeakDeviceId<FakeDeviceId>,
1679    > = ConnAddr {
1680        ip: ConnIpAddr {
1681            local: (
1682                unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
1683                NonZeroU16::new(1).unwrap(),
1684            ),
1685            remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
1686        },
1687        device: None,
1688    };
1689
1690    #[test]
1691    fn bound_insert_get_remove_listener() {
1692        set_logger_for_test();
1693        let mut bound = FakeBoundSocketMap::default();
1694        let mut fake_id_gen = FakeSocketIdGen::default();
1695
1696        let addr = LISTENER_ADDR;
1697
1698        let id = {
1699            let entry = bound
1700                .listeners_mut()
1701                .try_insert(addr, SharingState::exclusive('v'), Listener(fake_id_gen.next()))
1702                .unwrap();
1703            assert_eq!(entry.get_addr(), &addr);
1704            entry.id().clone()
1705        };
1706
1707        assert_eq!(
1708            bound.listeners().get_by_addr(&addr),
1709            Some(&Multiple::new_exclusive('v', vec![id]))
1710        );
1711
1712        assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
1713        assert_eq!(bound.listeners().get_by_addr(&addr), None);
1714    }
1715
1716    #[test]
1717    fn bound_insert_get_remove_conn() {
1718        set_logger_for_test();
1719        let mut bound = FakeBoundSocketMap::default();
1720        let mut fake_id_gen = FakeSocketIdGen::default();
1721
1722        let addr = CONN_ADDR;
1723
1724        let id = {
1725            let entry = bound
1726                .conns_mut()
1727                .try_insert(addr, SharingState::exclusive('v'), Conn(fake_id_gen.next()))
1728                .unwrap();
1729            assert_eq!(entry.get_addr(), &addr);
1730            entry.id().clone()
1731        };
1732
1733        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('v', vec![id])));
1734
1735        assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
1736        assert_eq!(bound.conns().get_by_addr(&addr), None);
1737    }
1738
1739    #[test]
1740    fn bound_iter_addrs() {
1741        set_logger_for_test();
1742        let mut bound = FakeBoundSocketMap::default();
1743        let mut fake_id_gen = FakeSocketIdGen::default();
1744
1745        let listener_addrs = [
1746            (Some(net_ip_v4!("1.1.1.1")), 1),
1747            (Some(net_ip_v4!("2.2.2.2")), 2),
1748            (Some(net_ip_v4!("1.1.1.1")), 3),
1749            (None, 4),
1750        ]
1751        .map(|(ip, identifier)| ListenerAddr {
1752            device: None,
1753            ip: ListenerIpAddr {
1754                addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
1755                identifier: NonZeroU16::new(identifier).unwrap(),
1756            },
1757        });
1758        let conn_addrs = [
1759            (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
1760            (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
1761        ]
1762        .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
1763            ip: ConnIpAddr {
1764                local: (
1765                    SocketIpAddr::new(local_ip).unwrap(),
1766                    NonZeroU16::new(local_identifier).unwrap(),
1767                ),
1768                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
1769            },
1770            device: None,
1771        });
1772
1773        for addr in listener_addrs.iter().cloned() {
1774            let _entry = bound
1775                .listeners_mut()
1776                .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1777                .unwrap();
1778        }
1779        for addr in conn_addrs.iter().cloned() {
1780            let _entry = bound
1781                .conns_mut()
1782                .try_insert(addr, SharingState::exclusive('a'), Conn(fake_id_gen.next()))
1783                .unwrap();
1784        }
1785        let expected_addrs = listener_addrs
1786            .into_iter()
1787            .map(Into::into)
1788            .chain(conn_addrs.into_iter().map(Into::into))
1789            .collect::<HashSet<_>>();
1790
1791        assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
1792    }
1793
1794    #[test]
1795    fn try_insert_with_callback_not_called_on_error() {
1796        // TODO(https://fxbug.dev/42076891): remove this test along with
1797        // try_insert_with.
1798        set_logger_for_test();
1799        let mut bound = FakeBoundSocketMap::default();
1800        let addr = LISTENER_ADDR;
1801
1802        // Insert a listener so that future calls can conflict.
1803        let _: &Listener = bound
1804            .listeners_mut()
1805            .try_insert(addr, SharingState::exclusive('a'), Listener(0))
1806            .unwrap()
1807            .id();
1808
1809        // All of the below try_insert_with calls should fail, but more
1810        // importantly, they should not call the `make_id` callback (because it
1811        // is only called once success is certain).
1812        fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
1813            panic!("should never be called");
1814        }
1815
1816        assert_matches!(
1817            bound.listeners_mut().try_insert_with(
1818                addr,
1819                SharingState::exclusive('b'),
1820                is_never_called
1821            ),
1822            Err((InsertError::Exists, SharingState { .. }))
1823        );
1824        assert_matches!(
1825            bound.listeners_mut().try_insert_with(
1826                ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
1827                SharingState::exclusive('b'),
1828                is_never_called
1829            ),
1830            Err((InsertError::ShadowAddrExists, _))
1831        );
1832        assert_matches!(
1833            bound.conns_mut().try_insert_with(
1834                ConnAddr {
1835                    device: None,
1836                    ip: ConnIpAddr {
1837                        local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1838                        remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1839                    },
1840                },
1841                SharingState::exclusive('b'),
1842                is_never_called,
1843            ),
1844            Err((InsertError::ShadowAddrExists, _))
1845        );
1846    }
1847
1848    #[test]
1849    fn insert_listener_conflict_with_listener() {
1850        set_logger_for_test();
1851        let mut bound = FakeBoundSocketMap::default();
1852        let mut fake_id_gen = FakeSocketIdGen::default();
1853        let addr = LISTENER_ADDR;
1854
1855        let _: &Listener = bound
1856            .listeners_mut()
1857            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1858            .unwrap()
1859            .id();
1860        assert_matches!(
1861            bound.listeners_mut().try_insert(
1862                addr,
1863                SharingState::exclusive('b'),
1864                Listener(fake_id_gen.next())
1865            ),
1866            Err((InsertError::Exists, SharingState { tag: 'b', .. }))
1867        );
1868    }
1869
1870    #[test]
1871    fn insert_listener_conflict_with_shadower() {
1872        set_logger_for_test();
1873        let mut bound = FakeBoundSocketMap::default();
1874        let mut fake_id_gen = FakeSocketIdGen::default();
1875        let addr = LISTENER_ADDR;
1876        let shadows_addr = {
1877            assert_eq!(addr.device, None);
1878            ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
1879        };
1880
1881        let _: &Listener = bound
1882            .listeners_mut()
1883            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1884            .unwrap()
1885            .id();
1886        assert_matches!(
1887            bound.listeners_mut().try_insert(
1888                shadows_addr,
1889                SharingState::exclusive('b'),
1890                Listener(fake_id_gen.next())
1891            ),
1892            Err((InsertError::ShadowAddrExists, SharingState { tag: 'b', .. }))
1893        );
1894    }
1895
1896    #[test]
1897    fn insert_conn_conflict_with_listener() {
1898        set_logger_for_test();
1899        let mut bound = FakeBoundSocketMap::default();
1900        let mut fake_id_gen = FakeSocketIdGen::default();
1901        let addr = LISTENER_ADDR;
1902        let shadows_addr = ConnAddr {
1903            device: None,
1904            ip: ConnIpAddr {
1905                local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1906                remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1907            },
1908        };
1909
1910        let _: &Listener = bound
1911            .listeners_mut()
1912            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1913            .unwrap()
1914            .id();
1915        assert_matches!(
1916            bound.conns_mut().try_insert(
1917                shadows_addr,
1918                SharingState::exclusive('b'),
1919                Conn(fake_id_gen.next())
1920            ),
1921            Err((InsertError::ShadowAddrExists, SharingState { tag: 'b', .. }))
1922        );
1923    }
1924
1925    #[test]
1926    fn insert_and_remove_listener() {
1927        set_logger_for_test();
1928        let mut bound = FakeBoundSocketMap::default();
1929        let mut fake_id_gen = FakeSocketIdGen::default();
1930        let addr = LISTENER_ADDR;
1931
1932        let a = bound
1933            .listeners_mut()
1934            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1935            .unwrap()
1936            .id()
1937            .clone();
1938        let b = bound
1939            .listeners_mut()
1940            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1941            .unwrap()
1942            .id()
1943            .clone();
1944        assert_ne!(a, b);
1945
1946        assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
1947        assert_eq!(
1948            bound.listeners().get_by_addr(&addr),
1949            Some(&Multiple::new_exclusive('x', vec![b]))
1950        );
1951    }
1952
1953    #[test]
1954    fn insert_and_remove_conn() {
1955        set_logger_for_test();
1956        let mut bound = FakeBoundSocketMap::default();
1957        let mut fake_id_gen = FakeSocketIdGen::default();
1958        let addr = CONN_ADDR;
1959
1960        let a = bound
1961            .conns_mut()
1962            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
1963            .unwrap()
1964            .id()
1965            .clone();
1966        let b = bound
1967            .conns_mut()
1968            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
1969            .unwrap()
1970            .id()
1971            .clone();
1972        assert_ne!(a, b);
1973
1974        assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
1975        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('x', vec![b])));
1976    }
1977
1978    #[test]
1979    fn update_listener_to_shadowed_addr_fails() {
1980        let mut bound = FakeBoundSocketMap::default();
1981        let mut fake_id_gen = FakeSocketIdGen::default();
1982
1983        let first_addr = LISTENER_ADDR;
1984        let second_addr = ListenerAddr {
1985            ip: ListenerIpAddr {
1986                addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
1987                ..LISTENER_ADDR.ip
1988            },
1989            ..LISTENER_ADDR
1990        };
1991        let both_shadow = ListenerAddr {
1992            ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
1993            device: None,
1994        };
1995
1996        let first = bound
1997            .listeners_mut()
1998            .try_insert(first_addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1999            .unwrap()
2000            .id()
2001            .clone();
2002        let second = bound
2003            .listeners_mut()
2004            .try_insert(second_addr, SharingState::exclusive('b'), Listener(fake_id_gen.next()))
2005            .unwrap()
2006            .id()
2007            .clone();
2008
2009        // Moving from (1, "aaa") to (1, None) should fail since it is shadowed
2010        // by (1, "yyy"), and vise versa.
2011        let (ExistsError, entry) = bound
2012            .listeners_mut()
2013            .entry(&second, &second_addr)
2014            .unwrap()
2015            .try_update_addr(both_shadow)
2016            .expect_err("update should fail");
2017
2018        // The entry should correspond to `second`.
2019        assert_eq!(entry.id(), &second);
2020        drop(entry);
2021
2022        let (ExistsError, entry) = bound
2023            .listeners_mut()
2024            .entry(&first, &first_addr)
2025            .unwrap()
2026            .try_update_addr(both_shadow)
2027            .expect_err("update should fail");
2028        assert_eq!(entry.get_addr(), &first_addr);
2029    }
2030
2031    #[test]
2032    fn nonexistent_conn_entry() {
2033        let mut map = FakeBoundSocketMap::default();
2034        let mut fake_id_gen = FakeSocketIdGen::default();
2035        let addr = CONN_ADDR;
2036        let conn_id = map
2037            .conns_mut()
2038            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2039            .expect("failed to insert")
2040            .id()
2041            .clone();
2042        assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
2043
2044        assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
2045    }
2046
2047    #[test]
2048    fn update_conn_sharing() {
2049        let mut map = FakeBoundSocketMap::default();
2050        let mut fake_id_gen = FakeSocketIdGen::default();
2051        let addr = CONN_ADDR;
2052        let mut entry = map
2053            .conns_mut()
2054            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2055            .expect("failed to insert");
2056
2057        entry
2058            .try_update_sharing(&SharingState::exclusive('a'), SharingState::exclusive('d'))
2059            .expect("worked");
2060        // Updating sharing is only allowed if there are no other occupants at
2061        // the address.
2062        let mut second_conn = map
2063            .conns_mut()
2064            .try_insert(addr.clone(), SharingState::exclusive('d'), Conn(fake_id_gen.next()))
2065            .expect("can insert");
2066        assert_matches!(
2067            second_conn
2068                .try_update_sharing(&SharingState::exclusive('d'), SharingState::exclusive('e')),
2069            Err(UpdateSharingError)
2070        );
2071    }
2072
2073    #[test]
2074    fn lookup_connected() {
2075        let mut map = FakeBoundSocketMap::default();
2076        let mut fake_id_gen = FakeSocketIdGen::default();
2077
2078        let sharing_state = SharingState::shared('a');
2079
2080        let device_id = FakeWeakDeviceId(FakeDeviceId);
2081        let entry1 = map
2082            .conns_mut()
2083            .try_insert(CONN_ADDR, sharing_state, Conn(fake_id_gen.next()))
2084            .expect("failed to insert")
2085            .id()
2086            .clone();
2087        let conn = map
2088            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2089            .expect("lookup should succeed");
2090        assert!(conn.contains_id(&entry1));
2091
2092        // Add a second entry with a device ID. This one should be preferred
2093        // over the first one.
2094        let addr_with_device = ConnAddr { device: Some(device_id), ..CONN_ADDR };
2095        let entry2 = map
2096            .conns_mut()
2097            .try_insert(addr_with_device, sharing_state, Conn(fake_id_gen.next()))
2098            .expect("failed to insert")
2099            .id()
2100            .clone();
2101        let conn = map
2102            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2103            .expect("lookup should succeed");
2104        assert!(conn.contains_id(&entry2));
2105    }
2106}