use core::convert::Infallible as Never;
use core::fmt::Debug;
use core::hash::Hash;
use core::marker::PhantomData;
use core::num::NonZeroU16;
use derivative::Derivative;
use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
use net_types::{
AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
};
use crate::data_structures::socketmap::{
Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
};
use crate::device::{
DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
};
use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
use crate::ip::BroadcastIpExt;
use crate::socket::address::{
AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
};
pub trait DualStackIpExt: Ip {
type OtherVersion: DualStackIpExt<OtherVersion = Self>;
}
impl DualStackIpExt for Ipv4 {
type OtherVersion = Ipv6;
}
impl DualStackIpExt for Ipv6 {
type OtherVersion = Ipv4;
}
pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
this_stack: <T as GenericOverIp<I>>::Type,
other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
_marker: IpVersionMarker<I>,
}
impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
where
T: GenericOverIp<I, Type = T>,
{
Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
}
pub fn into_inner(
self,
) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
let Self { this_stack, other_stack, _marker } = self;
(this_stack, other_stack)
}
pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
self.this_stack
}
pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
&self.this_stack
}
pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
self.other_stack
}
pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
&self.other_stack
}
pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
let Self { this_stack, other_stack, _marker } = self;
DualStackTuple {
this_stack: other_stack,
other_stack: this_stack,
_marker: IpVersionMarker::new(),
}
}
pub fn cast<X>(self) -> DualStackTuple<X, T>
where
X: DualStackIpExt,
T: GenericOverIp<X>
+ GenericOverIp<X::OtherVersion>
+ GenericOverIp<Ipv4>
+ GenericOverIp<Ipv6>,
{
I::map_ip_in(
self,
|v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
|v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
)
}
}
impl<
I: DualStackIpExt,
NewIp: DualStackIpExt,
T: GenericOverIp<NewIp>
+ GenericOverIp<NewIp::OtherVersion>
+ GenericOverIp<I>
+ GenericOverIp<I::OtherVersion>,
> GenericOverIp<NewIp> for DualStackTuple<I, T>
{
type Type = DualStackTuple<NewIp, T>;
}
pub trait SocketIpExt: Ip {
const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
};
}
impl<I: Ip> SocketIpExt for I {}
#[cfg(test)]
mod socket_ip_ext_test {
use super::*;
use ip_test_macro::ip_test;
#[ip_test(I)]
fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
.expect("loopback address should be a valid SocketIpAddr");
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
ThisStack(T),
OtherStack(O),
}
impl<T, O> Clone for EitherStack<T, O>
where
T: Clone,
O: Clone,
{
#[cfg_attr(feature = "instrumented", track_caller)]
fn clone(&self) -> Self {
match self {
Self::ThisStack(t) => Self::ThisStack(t.clone()),
Self::OtherStack(t) => Self::OtherStack(t.clone()),
}
}
}
#[derive(Debug)]
#[allow(missing_docs)]
pub enum MaybeDualStack<DS, NDS> {
DualStack(DS),
NotDualStack(NDS),
}
impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
for MaybeDualStack<DS, NDS>
{
type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
}
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
SocketIsBound,
NotCapable,
}
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub struct NotDualStackCapableError;
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
pub send: bool,
pub receive: bool,
}
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
Send,
Receive,
SendAndReceive,
}
impl ShutdownType {
pub fn to_send_receive(&self) -> (bool, bool) {
match self {
Self::Send => (true, false),
Self::Receive => (false, true),
Self::SendAndReceive => (true, true),
}
}
pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
match (send, receive) {
(true, false) => Some(Self::Send),
(false, true) => Some(Self::Receive),
(true, true) => Some(Self::SendAndReceive),
(false, false) => None,
}
}
}
pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
fn must_have_zone(&self) -> bool
where
Self: Copy,
{
self.try_into_null_zoned().is_some()
}
fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
if self.get().is_loopback() {
return None;
}
AddrAndZone::new(self, ())
}
}
impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
pub trait SocketZonedAddrExt<W, A, D> {
fn resolve_addr_with_device(
self,
device: Option<D::Weak>,
) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
where
D: StrongDeviceIdentifier;
}
impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
where
W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
A: IpAddress,
{
fn resolve_addr_with_device(
self,
device: Option<D::Weak>,
) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
where
D: StrongDeviceIdentifier,
{
let (addr, zone) = self.into_addr_zone();
let device = match (zone, device) {
(Some(zone), Some(device)) => {
if device != zone {
return Err(ZonedAddressError::DeviceZoneMismatch);
}
Some(EitherDeviceId::Strong(zone))
}
(Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
(None, Some(device)) => Some(EitherDeviceId::Weak(device)),
(None, None) => {
if addr.as_ref().must_have_zone() {
return Err(ZonedAddressError::RequiredZoneNotProvided);
} else {
None
}
}
};
Ok((addr, device))
}
}
pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
pub local_ip: Option<&'a SpecifiedAddr<A>>,
pub remote_ip: Option<&'a SpecifiedAddr<A>>,
pub old_device: Option<&'a D>,
}
impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
pub fn check_update<N>(
self,
new_device: Option<&N>,
) -> Result<(), SocketDeviceUpdateNotAllowedError>
where
D: PartialEq<N>,
{
let Self { local_ip, remote_ip, old_device } = self;
let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
|| remote_ip.is_some_and(|a| a.must_have_zone());
if !must_have_zone {
return Ok(());
}
let old_device = old_device.unwrap_or_else(|| {
panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
});
if new_device.is_some_and(|new_device| old_device == new_device) {
Ok(())
} else {
Err(SocketDeviceUpdateNotAllowedError)
}
}
}
pub struct SocketDeviceUpdateNotAllowedError;
pub trait SocketMapAddrSpec {
type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}
pub struct ListenerAddrInfo {
pub has_device: bool,
pub specified_addr: bool,
}
impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
pub(crate) fn info(&self) -> ListenerAddrInfo {
let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
}
}
pub trait SocketMapStateSpec {
type AddrVecTag: Eq + Copy + Debug + 'static;
fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;
fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;
type ListenerId: Clone + Debug;
type ConnId: Clone + Debug;
type ListenerSharingState: Clone + Debug;
type ConnSharingState: Clone + Debug;
type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
+ Debug;
type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
+ Debug;
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
pub trait Inserter<T> {
fn insert(self, item: T);
}
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
fn insert(self, item: T) {
self.extend([item])
}
}
impl<T> Inserter<T> for Never {
fn insert(self, _: T) {
match self {}
}
}
pub trait SocketMapAddrStateSpec {
type Id;
type SharingState;
type Inserter<'a>: Inserter<Self::Id> + 'a
where
Self: 'a,
Self::Id: 'a;
fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;
fn contains_id(&self, id: &Self::Id) -> bool;
fn try_get_inserter<'a, 'b>(
&'b mut self,
new_sharing_state: &'a Self::SharingState,
) -> Result<Self::Inserter<'b>, IncompatibleError>;
fn could_insert(&self, new_sharing_state: &Self::SharingState)
-> Result<(), IncompatibleError>;
fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
}
pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
fn try_update_sharing(
&mut self,
id: Self::Id,
new_sharing_state: &Self::SharingState,
) -> Result<(), IncompatibleError>;
}
pub trait SocketMapConflictPolicy<
Addr,
SharingState,
I: Ip,
D: DeviceIdentifier,
A: SocketMapAddrSpec,
>: SocketMapStateSpec
{
fn check_insert_conflicts(
new_sharing_state: &SharingState,
addr: &Addr,
socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
) -> Result<(), InsertError>;
}
pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
A: SocketMapAddrSpec,
{
fn allows_sharing_update(
socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
addr: &Addr,
old_sharing: &SharingState,
new_sharing: &SharingState,
) -> Result<(), UpdateSharingError>;
}
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
#[allow(missing_docs)]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
Listen(S::ListenerAddrState),
Conn(S::ConnAddrState),
}
#[derive(Derivative)]
#[derivative(
Debug(bound = "D: Debug"),
Clone(bound = "D: Clone"),
Eq(bound = "D: Eq"),
PartialEq(bound = "D: PartialEq"),
Hash(bound = "D: Hash")
)]
#[allow(missing_docs)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
Tagged<AddrVec<I, D, A>> for Bound<S>
{
type Tag = S::AddrVecTag;
fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
match (self, address) {
(Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
(Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
S::connected_tag(device.is_some(), c)
}
(Bound::Listen(_), AddrVec::Conn(_)) => {
unreachable!("found listen state for conn addr")
}
(Bound::Conn(_), AddrVec::Listen(_)) => {
unreachable!("found conn state for listen addr")
}
}
}
}
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
type IterShadows = AddrVecIter<I, D, A>;
fn iter_shadows(&self) -> Self::IterShadows {
let (socket_ip_addr, device) = match self.clone() {
AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
};
let mut iter = match device {
Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
None => AddrVecIter::without_device(socket_ip_addr),
};
assert_eq!(iter.next().as_ref(), Some(self));
iter
}
}
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[allow(missing_docs)]
pub enum SocketAddrType {
AnyListener,
SpecificListener,
Connected,
}
impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
match addr {
Some(_) => SocketAddrType::SpecificListener,
None => SocketAddrType::AnyListener,
}
}
}
impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
SocketAddrType::Connected
}
}
pub enum RemoveResult {
Success,
IsLast,
}
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub enum SocketId<S: SocketMapStateSpec> {
Listener(S::ListenerId),
Connection(S::ConnId),
}
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
BoundSocketMap<I, D, A, S>
{
pub fn len(&self) -> usize {
self.addr_to_state.len()
}
}
pub enum Listener {}
pub enum Connection {}
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
impl<
'a,
I: Ip,
D: DeviceIdentifier,
SocketType: ConvertSocketMapState<I, D, A, S>,
A: SocketMapAddrSpec,
S: SocketMapStateSpec,
> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
where
S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
{
pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
let Self(addr_to_state, _marker) = self;
addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
SocketType::from_bound_ref(state)
.unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
})
}
pub fn could_insert(
self,
addr: &SocketType::Addr,
sharing: &SocketType::SharingState,
) -> Result<(), InsertError> {
let Self(addr_to_state, _) = self;
match self.get_by_addr(addr) {
Some(state) => {
state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
}
None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
}
}
}
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct SocketStateEntry<
'a,
I: Ip,
D: DeviceIdentifier,
A: SocketMapAddrSpec,
S: SocketMapStateSpec,
SocketType,
> {
id: SocketId<S>,
addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
_marker: PhantomData<SocketType>,
}
impl<
'a,
I: Ip,
D: DeviceIdentifier,
SocketType: ConvertSocketMapState<I, D, A, S>,
A: SocketMapAddrSpec,
S: SocketMapStateSpec
+ SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
where
SocketType::SharingState: Clone,
SocketType::Id: Clone,
{
pub fn try_insert(
self,
socket_addr: SocketType::Addr,
tag_state: SocketType::SharingState,
id: SocketType::Id,
) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, (InsertError, SocketType::SharingState)>
{
self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
.map(|(entry, ())| entry)
}
pub fn try_insert_with<R>(
self,
socket_addr: SocketType::Addr,
tag_state: SocketType::SharingState,
make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
) -> Result<
(SocketStateEntry<'a, I, D, A, S, SocketType>, R),
(InsertError, SocketType::SharingState),
> {
let Self(addr_to_state, _) = self;
match S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state) {
Err(e) => return Err((e, tag_state)),
Ok(()) => (),
};
let addr = SocketType::to_addr_vec(&socket_addr);
match addr_to_state.entry(addr) {
Entry::Occupied(mut o) => {
let (id, ret) = o.map_mut(|bound| {
let bound = match SocketType::from_bound_mut(bound) {
Some(bound) => bound,
None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
};
match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
bound, &tag_state,
) {
Ok(v) => {
let (id, ret) = make_id(socket_addr, tag_state);
v.insert(id.clone());
Ok((SocketType::to_socket_id(id), ret))
}
Err(IncompatibleError) => Err((InsertError::Exists, tag_state)),
}
})?;
Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
}
Entry::Vacant(v) => {
let (id, ret) = make_id(socket_addr, tag_state.clone());
let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
&tag_state,
id.clone(),
)));
let id = SocketType::to_socket_id(id);
Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
}
}
}
pub fn entry(
self,
id: &SocketType::Id,
addr: &SocketType::Addr,
) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
let Self(addr_to_state, _) = self;
let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
Entry::Vacant(_) => return None,
Entry::Occupied(o) => o,
};
let state = SocketType::from_bound_ref(addr_entry.get())?;
state.contains_id(id).then_some(SocketStateEntry {
id: SocketType::to_socket_id(id.clone()),
addr_entry,
_marker: PhantomData::default(),
})
}
pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
self.entry(id, addr)
.map(|entry| {
entry.remove();
})
.ok_or(NotFoundError)
}
}
#[derive(Debug)]
pub struct UpdateSharingError;
impl<
'a,
I: Ip,
D: DeviceIdentifier,
SocketType: ConvertSocketMapState<I, D, A, S>,
A: SocketMapAddrSpec,
S: SocketMapStateSpec,
> SocketStateEntry<'a, I, D, A, S, SocketType>
where
SocketType::Id: Clone,
{
pub fn get_addr(&self) -> &SocketType::Addr {
let Self { id: _, addr_entry, _marker } = self;
SocketType::from_addr_vec_ref(addr_entry.key())
}
pub fn id(&self) -> &SocketType::Id {
let Self { id, addr_entry: _, _marker } = self;
SocketType::from_socket_id_ref(id)
}
pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
let Self { id, addr_entry, _marker } = self;
let new_addrvec = SocketType::to_addr_vec(&new_addr);
let old_addr = addr_entry.key().clone();
let (addr_state, addr_to_state) = addr_entry.remove_from_map();
let addr_to_state = match addr_to_state.entry(new_addrvec) {
Entry::Occupied(o) => o.into_map(),
Entry::Vacant(v) => {
if v.descendant_counts().len() != 0 {
v.into_map()
} else {
let new_addr_entry = v.insert(addr_state);
return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
}
}
};
let to_restore = addr_state;
let addr_entry = match addr_to_state.entry(old_addr) {
Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
Entry::Vacant(v) => v.insert(to_restore),
};
return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
}
pub fn remove(self) {
let Self { id, mut addr_entry, _marker } = self;
let addr = addr_entry.key().clone();
match addr_entry.map_mut(|value| {
let value = match SocketType::from_bound_mut(value) {
Some(value) => value,
None => unreachable!("found {:?} for address {:?}", value, addr),
};
value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
}) {
RemoveResult::Success => (),
RemoveResult::IsLast => {
let _: Bound<S> = addr_entry.remove();
}
}
}
pub fn try_update_sharing(
&mut self,
old_sharing_state: &SocketType::SharingState,
new_sharing_state: SocketType::SharingState,
) -> Result<(), UpdateSharingError>
where
SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
{
let Self { id, addr_entry, _marker } = self;
let addr = SocketType::from_addr_vec_ref(addr_entry.key());
S::allows_sharing_update(
addr_entry.get_map(),
addr,
old_sharing_state,
&new_sharing_state,
)?;
addr_entry
.map_mut(|value| {
let value = match SocketType::from_bound_mut(value) {
Some(value) => value,
None => unreachable!("found invalid state {:?}", value),
};
value.try_update_sharing(
SocketType::from_socket_id_ref(id).clone(),
&new_sharing_state,
)
})
.map_err(|IncompatibleError| UpdateSharingError)
}
}
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
where
AddrVec<I, D, A>: IterShadows,
S: SocketMapStateSpec,
{
pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
where
S: SocketMapConflictPolicy<
ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
<S as SocketMapStateSpec>::ListenerSharingState,
I,
D,
A,
>,
S::ListenerAddrState:
SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
where
S: SocketMapConflictPolicy<
ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
<S as SocketMapStateSpec>::ListenerSharingState,
I,
D,
A,
>,
S::ListenerAddrState:
SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
where
S: SocketMapConflictPolicy<
ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
<S as SocketMapStateSpec>::ConnSharingState,
I,
D,
A,
>,
S::ConnAddrState:
SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
where
S: SocketMapConflictPolicy<
ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
<S as SocketMapStateSpec>::ConnSharingState,
I,
D,
A,
>,
S::ConnAddrState:
SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
#[cfg(test)]
pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
let Self { addr_to_state } = self;
addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
}
pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
let Self { addr_to_state } = self;
addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
}
}
pub enum FoundSockets<A, It> {
Single(A),
Multicast(It),
}
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
Conn(
&'a S::ConnAddrState,
ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
),
}
impl<I, D, A, S> BoundSocketMap<I, D, A, S>
where
I: BroadcastIpExt<Addr: MulticastAddress>,
D: DeviceIdentifier,
A: SocketMapAddrSpec,
S: SocketMapStateSpec
+ SocketMapConflictPolicy<
ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
<S as SocketMapStateSpec>::ListenerSharingState,
I,
D,
A,
> + SocketMapConflictPolicy<
ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
<S as SocketMapStateSpec>::ConnSharingState,
I,
D,
A,
>,
{
pub fn iter_receivers(
&self,
(src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
(dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
device: D,
broadcast: Option<I::BroadcastMarker>,
) -> Option<
FoundSockets<
AddrEntry<'_, I, D, A, S>,
impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
>,
> {
let mut matching_entries = AddrVecIter::with_device(
match (src_ip, src_port) {
(Some(specified_src_ip), Some(src_port)) => {
ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
.into()
}
_ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
},
device,
)
.filter_map(move |addr: AddrVec<I, D, A>| match addr {
AddrVec::Listen(l) => {
self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
}
AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
});
if broadcast.is_some() || dst_ip.addr().is_multicast() {
Some(FoundSockets::Multicast(matching_entries))
} else {
let single_entry: Option<_> = matching_entries.next();
single_entry.map(FoundSockets::Single)
}
}
}
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
ShadowAddrExists,
Exists,
ShadowerExists,
IndirectConflict,
}
pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
type Id;
type SharingState;
type Addr: Debug;
type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;
fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
fn to_bound(state: Self::AddrState) -> Bound<S>;
fn to_socket_id(id: Self::Id) -> SocketId<S>;
fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
ConvertSocketMapState<I, D, A, S> for Listener
{
type Id = S::ListenerId;
type SharingState = S::ListenerSharingState;
type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
type AddrState = S::ListenerAddrState;
fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
AddrVec::Listen(addr.clone())
}
fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
match addr {
AddrVec::Listen(l) => l,
AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
}
}
fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
match bound {
Bound::Listen(l) => Some(l),
Bound::Conn(_c) => None,
}
}
fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
match bound {
Bound::Listen(l) => Some(l),
Bound::Conn(_c) => None,
}
}
fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
Bound::Listen(state)
}
fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
match id {
SocketId::Listener(id) => id,
SocketId::Connection(_) => unreachable!("connection ID for listener"),
}
}
fn to_socket_id(id: Self::Id) -> SocketId<S> {
SocketId::Listener(id)
}
}
impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
ConvertSocketMapState<I, D, A, S> for Connection
{
type Id = S::ConnId;
type SharingState = S::ConnSharingState;
type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
type AddrState = S::ConnAddrState;
fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
AddrVec::Conn(addr.clone())
}
fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
match addr {
AddrVec::Conn(c) => c,
AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
}
}
fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
match bound {
Bound::Listen(_l) => None,
Bound::Conn(c) => Some(c),
}
}
fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
match bound {
Bound::Listen(_l) => None,
Bound::Conn(c) => Some(c),
}
}
fn to_bound(state: S::ConnAddrState) -> Bound<S> {
Bound::Conn(state)
}
fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
match id {
SocketId::Connection(id) => id,
SocketId::Listener(_) => unreachable!("listener ID for connection"),
}
}
fn to_socket_id(id: Self::Id) -> SocketId<S> {
SocketId::Connection(id)
}
}
#[cfg(test)]
mod tests {
use alloc::collections::HashSet;
use alloc::vec;
use alloc::vec::Vec;
use assert_matches::assert_matches;
use net_declare::{net_ip_v4, net_ip_v6};
use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
use test_case::test_case;
use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
use crate::testutil::set_logger_for_test;
use super::*;
#[test_case(net_ip_v4!("8.8.8.8"))]
#[test_case(net_ip_v4!("127.0.0.1"))]
#[test_case(net_ip_v4!("127.0.8.9"))]
#[test_case(net_ip_v4!("224.1.2.3"))]
fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
let addr = SpecifiedAddr::new(addr).unwrap();
assert_eq!(addr.must_have_zone(), false);
}
#[test_case(net_ip_v6!("1::2:3"), false)]
#[test_case(net_ip_v6!("::1"), false; "localhost")]
#[test_case(net_ip_v6!("1::"), false)]
#[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
#[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
#[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
#[test_case(net_ip_v6!("fe80::1"), true)]
fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
let addr = SpecifiedAddr::new(addr).unwrap();
assert_eq!(addr.must_have_zone(), must_have);
}
#[test]
fn try_into_null_zoned_ipv6() {
assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
const ZONE: u32 = 5;
assert_eq!(
zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
Some(AddrAndZone::new(zoned, ZONE).unwrap())
);
}
enum FakeSpec {}
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
struct Listener(usize);
#[derive(PartialEq, Eq, Debug)]
struct Multiple<T>(char, Vec<T>);
impl<T> Multiple<T> {
fn tag(&self) -> char {
let Multiple(c, _) = self;
*c
}
}
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
struct Conn(usize);
enum FakeAddrSpec {}
impl SocketMapAddrSpec for FakeAddrSpec {
type LocalIdentifier = NonZeroU16;
type RemoteIdentifier = ();
}
impl SocketMapStateSpec for FakeSpec {
type AddrVecTag = char;
type ListenerId = Listener;
type ConnId = Conn;
type ListenerSharingState = char;
type ConnSharingState = char;
type ListenerAddrState = Multiple<Listener>;
type ConnAddrState = Multiple<Conn>;
fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
state.tag()
}
fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
state.tag()
}
}
type FakeBoundSocketMap =
BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
#[derive(Default)]
struct FakeSocketIdGen {
next_id: usize,
}
impl FakeSocketIdGen {
fn next(&mut self) -> usize {
let next_next_id = self.next_id + 1;
core::mem::replace(&mut self.next_id, next_next_id)
}
}
impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
type Id = I;
type SharingState = char;
type Inserter<'a>
= &'a mut Vec<I>
where
I: 'a;
fn new(new_sharing_state: &char, id: I) -> Self {
Self(*new_sharing_state, vec![id])
}
fn contains_id(&self, id: &Self::Id) -> bool {
self.1.contains(id)
}
fn try_get_inserter<'a, 'b>(
&'b mut self,
new_state: &'a char,
) -> Result<Self::Inserter<'b>, IncompatibleError> {
let Self(c, v) = self;
(new_state == c).then_some(v).ok_or(IncompatibleError)
}
fn could_insert(
&self,
new_sharing_state: &Self::SharingState,
) -> Result<(), IncompatibleError> {
let Self(c, _) = self;
(new_sharing_state == c).then_some(()).ok_or(IncompatibleError)
}
fn remove_by_id(&mut self, id: I) -> RemoveResult {
let Self(_, v) = self;
let index = v.iter().position(|i| i == &id).expect("did not find id");
let _: I = v.swap_remove(index);
if v.is_empty() {
RemoveResult::IsLast
} else {
RemoveResult::Success
}
}
}
impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
SocketMapConflictPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
for FakeSpec
{
fn check_insert_conflicts(
new_state: &char,
addr: &A,
socketmap: &SocketMap<
AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
Bound<FakeSpec>,
>,
) -> Result<(), InsertError> {
let dest = addr.clone().into();
if dest.iter_shadows().any(|a| socketmap.get(&a).is_some()) {
return Err(InsertError::ShadowAddrExists);
}
match socketmap.get(&dest) {
Some(Bound::Listen(Multiple(c, _))) | Some(Bound::Conn(Multiple(c, _))) => {
if c != new_state {
return Err(InsertError::Exists);
}
}
None => (),
}
if socketmap.descendant_counts(&dest).len() != 0 {
Err(InsertError::ShadowerExists)
} else {
Ok(())
}
}
}
impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
fn try_update_sharing(
&mut self,
id: Self::Id,
new_sharing_state: &Self::SharingState,
) -> Result<(), IncompatibleError> {
let Self(sharing, v) = self;
if new_sharing_state == sharing {
return Ok(());
}
if v.len() != 1 {
return Err(IncompatibleError);
}
assert!(v.contains(&id));
*sharing = *new_sharing_state;
Ok(())
}
}
impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
SocketMapUpdateSharingPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
for FakeSpec
{
fn allows_sharing_update(
_socketmap: &SocketMap<
AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
Bound<Self>,
>,
_addr: &A,
_old_sharing: &char,
_new_sharing_state: &char,
) -> Result<(), UpdateSharingError> {
Ok(())
}
}
const LISTENER_ADDR: ListenerAddr<
ListenerIpAddr<Ipv4Addr, NonZeroU16>,
FakeWeakDeviceId<FakeDeviceId>,
> = ListenerAddr {
ip: ListenerIpAddr {
addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
identifier: NonZeroU16::new(1).unwrap(),
},
device: None,
};
const CONN_ADDR: ConnAddr<
ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
FakeWeakDeviceId<FakeDeviceId>,
> = ConnAddr {
ip: ConnIpAddr {
local: (
unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
NonZeroU16::new(1).unwrap(),
),
remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
},
device: None,
};
#[test]
fn bound_insert_get_remove_listener() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = LISTENER_ADDR;
let id = {
let entry =
bound.listeners_mut().try_insert(addr, 'v', Listener(fake_id_gen.next())).unwrap();
assert_eq!(entry.get_addr(), &addr);
entry.id().clone()
};
assert_eq!(bound.listeners().get_by_addr(&addr), Some(&Multiple('v', vec![id])));
assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
assert_eq!(bound.listeners().get_by_addr(&addr), None);
}
#[test]
fn bound_insert_get_remove_conn() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = CONN_ADDR;
let id = {
let entry = bound.conns_mut().try_insert(addr, 'v', Conn(fake_id_gen.next())).unwrap();
assert_eq!(entry.get_addr(), &addr);
entry.id().clone()
};
assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple('v', vec![id])));
assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
assert_eq!(bound.conns().get_by_addr(&addr), None);
}
#[test]
fn bound_iter_addrs() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let listener_addrs = [
(Some(net_ip_v4!("1.1.1.1")), 1),
(Some(net_ip_v4!("2.2.2.2")), 2),
(Some(net_ip_v4!("1.1.1.1")), 3),
(None, 4),
]
.map(|(ip, identifier)| ListenerAddr {
device: None,
ip: ListenerIpAddr {
addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
identifier: NonZeroU16::new(identifier).unwrap(),
},
});
let conn_addrs = [
(net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
(net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
]
.map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
ip: ConnIpAddr {
local: (
SocketIpAddr::new(local_ip).unwrap(),
NonZeroU16::new(local_identifier).unwrap(),
),
remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
},
device: None,
});
for addr in listener_addrs.iter().cloned() {
let _entry =
bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap();
}
for addr in conn_addrs.iter().cloned() {
let _entry = bound.conns_mut().try_insert(addr, 'a', Conn(fake_id_gen.next())).unwrap();
}
let expected_addrs = listener_addrs
.into_iter()
.map(Into::into)
.chain(conn_addrs.into_iter().map(Into::into))
.collect::<HashSet<_>>();
assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
}
#[test]
fn try_insert_with_callback_not_called_on_error() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let addr = LISTENER_ADDR;
let _: &Listener = bound.listeners_mut().try_insert(addr, 'a', Listener(0)).unwrap().id();
fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
panic!("should never be called");
}
assert_matches!(
bound.listeners_mut().try_insert_with(addr, 'b', is_never_called),
Err((InsertError::Exists, _))
);
assert_matches!(
bound.listeners_mut().try_insert_with(
ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
'b',
is_never_called
),
Err((InsertError::ShadowAddrExists, _))
);
assert_matches!(
bound.conns_mut().try_insert_with(
ConnAddr {
device: None,
ip: ConnIpAddr {
local: (addr.ip.addr.unwrap(), addr.ip.identifier),
remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
},
},
'b',
is_never_called,
),
Err((InsertError::ShadowAddrExists, _))
);
}
#[test]
fn insert_listener_conflict_with_listener() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = LISTENER_ADDR;
let _: &Listener =
bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
assert_matches!(
bound.listeners_mut().try_insert(addr, 'b', Listener(fake_id_gen.next())),
Err((InsertError::Exists, 'b'))
);
}
#[test]
fn insert_listener_conflict_with_shadower() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = LISTENER_ADDR;
let shadows_addr = {
assert_eq!(addr.device, None);
ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
};
let _: &Listener =
bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
assert_matches!(
bound.listeners_mut().try_insert(shadows_addr, 'b', Listener(fake_id_gen.next())),
Err((InsertError::ShadowAddrExists, 'b'))
);
}
#[test]
fn insert_conn_conflict_with_listener() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = LISTENER_ADDR;
let shadows_addr = ConnAddr {
device: None,
ip: ConnIpAddr {
local: (addr.ip.addr.unwrap(), addr.ip.identifier),
remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
},
};
let _: &Listener =
bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
assert_matches!(
bound.conns_mut().try_insert(shadows_addr, 'b', Conn(fake_id_gen.next())),
Err((InsertError::ShadowAddrExists, 'b'))
);
}
#[test]
fn insert_and_remove_listener() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = LISTENER_ADDR;
let a = bound
.listeners_mut()
.try_insert(addr, 'x', Listener(fake_id_gen.next()))
.unwrap()
.id()
.clone();
let b = bound
.listeners_mut()
.try_insert(addr, 'x', Listener(fake_id_gen.next()))
.unwrap()
.id()
.clone();
assert_ne!(a, b);
assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
assert_eq!(bound.listeners().get_by_addr(&addr), Some(&Multiple('x', vec![b])));
}
#[test]
fn insert_and_remove_conn() {
set_logger_for_test();
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = CONN_ADDR;
let a =
bound.conns_mut().try_insert(addr, 'x', Conn(fake_id_gen.next())).unwrap().id().clone();
let b =
bound.conns_mut().try_insert(addr, 'x', Conn(fake_id_gen.next())).unwrap().id().clone();
assert_ne!(a, b);
assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple('x', vec![b])));
}
#[test]
fn update_listener_to_shadowed_addr_fails() {
let mut bound = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let first_addr = LISTENER_ADDR;
let second_addr = ListenerAddr {
ip: ListenerIpAddr {
addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
..LISTENER_ADDR.ip
},
..LISTENER_ADDR
};
let both_shadow = ListenerAddr {
ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
device: None,
};
let first = bound
.listeners_mut()
.try_insert(first_addr, 'a', Listener(fake_id_gen.next()))
.unwrap()
.id()
.clone();
let second = bound
.listeners_mut()
.try_insert(second_addr, 'b', Listener(fake_id_gen.next()))
.unwrap()
.id()
.clone();
let (ExistsError, entry) = bound
.listeners_mut()
.entry(&second, &second_addr)
.unwrap()
.try_update_addr(both_shadow)
.expect_err("update should fail");
assert_eq!(entry.id(), &second);
drop(entry);
let (ExistsError, entry) = bound
.listeners_mut()
.entry(&first, &first_addr)
.unwrap()
.try_update_addr(both_shadow)
.expect_err("update should fail");
assert_eq!(entry.get_addr(), &first_addr);
}
#[test]
fn nonexistent_conn_entry() {
let mut map = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = CONN_ADDR;
let conn_id = map
.conns_mut()
.try_insert(addr.clone(), 'a', Conn(fake_id_gen.next()))
.expect("failed to insert")
.id()
.clone();
assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
}
#[test]
fn update_conn_sharing() {
let mut map = FakeBoundSocketMap::default();
let mut fake_id_gen = FakeSocketIdGen::default();
let addr = CONN_ADDR;
let mut entry = map
.conns_mut()
.try_insert(addr.clone(), 'a', Conn(fake_id_gen.next()))
.expect("failed to insert");
entry.try_update_sharing(&'a', 'd').expect("worked");
let mut second_conn = map
.conns_mut()
.try_insert(addr.clone(), 'd', Conn(fake_id_gen.next()))
.expect("can insert");
assert_matches!(second_conn.try_update_sharing(&'d', 'e'), Err(UpdateSharingError));
}
}