pub(crate) mod accept_queue;
pub(crate) mod demux;
pub(crate) mod isn;
use alloc::collections::{hash_map, HashMap};
use core::convert::Infallible as Never;
use core::fmt::{self, Debug};
use core::marker::PhantomData;
use core::num::{NonZeroU16, NonZeroUsize};
use core::ops::{Deref, DerefMut, RangeInclusive};
use assert_matches::assert_matches;
use derivative::Derivative;
use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
use log::{debug, error, trace};
use net_types::ip::{
GenericOverIp, Ip, IpAddr, IpAddress, IpVersion, IpVersionMarker, Ipv4, Ipv4Addr, Ipv6,
Ipv6Addr,
};
use net_types::{AddrAndPortFormatter, AddrAndZone, SpecifiedAddr, ZonedAddr};
use netstack3_base::socket::{
self, AddrIsMappedError, AddrVec, Bound, ConnAddr, ConnIpAddr, DualStackListenerIpAddr,
DualStackLocalIp, DualStackRemoteIp, DualStackTuple, EitherStack, IncompatibleError,
InsertError, Inserter, ListenerAddr, ListenerAddrInfo, ListenerIpAddr, MaybeDualStack,
NotDualStackCapableError, RemoveResult, SetDualStackEnabledError, ShutdownType,
SocketDeviceUpdate, SocketDeviceUpdateNotAllowedError, SocketIpAddr, SocketIpExt,
SocketMapAddrSpec, SocketMapAddrStateSpec, SocketMapAddrStateUpdateSharingSpec,
SocketMapConflictPolicy, SocketMapStateSpec, SocketMapUpdateSharingPolicy,
SocketZonedAddrExt as _, UpdateSharingError,
};
use netstack3_base::socketmap::{IterShadows as _, SocketMap};
use netstack3_base::sync::RwLock;
use netstack3_base::{
trace_duration, AnyDevice, BidirectionalConverter as _, ContextPair, Control, CoreTimerContext,
CounterContext, CtxPair, DeferredResourceRemovalContext, DeviceIdContext, EitherDeviceId,
ExistsError, HandleableTimer, IcmpErrorCode, Inspector, InspectorDeviceExt,
InstantBindingsTypes, IpDeviceAddr, IpExt, LocalAddressError, Mss,
OwnedOrRefsBidirectionalConverter, PortAllocImpl, ReferenceNotifiersExt as _,
RemoveResourceResult, RngContext, Segment, SeqNum, StrongDeviceIdentifier as _,
TimerBindingsTypes, TimerContext, TracingContext, WeakDeviceIdentifier, ZonedAddressError,
};
use netstack3_filter::Tuple;
use netstack3_ip::socket::{
DeviceIpSocketHandler, IpSock, IpSockCreateAndSendError, IpSockCreationError, IpSocketHandler,
};
use netstack3_ip::{self as ip, BaseTransportIpContext, Mark, MarkDomain, TransportIpContext};
use packet_formats::ip::IpProto;
use smallvec::{smallvec, SmallVec};
use thiserror::Error;
use crate::internal::base::{
BufferSizes, BuffersRefMut, ConnectionError, SocketOptions, TcpCounters, TcpIpSockOptions,
};
use crate::internal::buffer::{Buffer, IntoBuffers, ReceiveBuffer, SendBuffer};
use crate::internal::socket::accept_queue::{AcceptQueue, ListenerNotifier};
use crate::internal::socket::demux::tcp_serialize_segment;
use crate::internal::socket::isn::IsnGenerator;
use crate::internal::state::{
CloseError, CloseReason, Closed, Initial, NewlyClosed, State, Takeable, TakeableRef,
};
/// IP version usable by dual-stack TCP: both `Self` and its
/// `OtherVersion` implement [`DualStackBaseIpExt`].
pub trait DualStackIpExt:
DualStackBaseIpExt + netstack3_base::socket::DualStackIpExt<OtherVersion: DualStackBaseIpExt>
{
}
/// Blanket implementation: every IP version that satisfies the supertrait
/// bounds is a [`DualStackIpExt`].
impl<I> DualStackIpExt for I where
I: DualStackBaseIpExt
+ netstack3_base::socket::DualStackIpExt<OtherVersion: DualStackBaseIpExt>
{
}
/// Per-IP-version TCP socket types and accessors.
///
/// The IPv4 and IPv6 implementations differ in where dual-stack state lives:
/// an IPv6 socket's `ConnectionAndAddr` may hold an IPv4 connection, while an
/// IPv4 `DemuxSocketId` may refer to an IPv6 socket.
pub trait DualStackBaseIpExt:
netstack3_base::socket::DualStackIpExt + SocketIpExt + netstack3_base::IpExt
{
/// Socket ID stored in this version's socket map (used as both the listener
/// and connection ID by `TcpSocketSpec`).
type DemuxSocketId<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>: SpecSocketId;
/// A connection's state paired with its connected address.
type ConnectionAndAddr<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>: Send + Sync + Debug;
/// Local IP address/port of a bound or listening socket.
type ListenerIpAddr: Send + Sync + Debug + Clone;
/// Input type of [`DualStackBaseIpExt::get_original_dst`].
type OriginalDstAddr;
/// Per-socket dual-stack options (`()` when the version has none).
type DualStackIpOptions: Send + Sync + Debug + Default + Clone + Copy;
/// Borrows a demux ID as a socket of either this or the other IP version.
fn as_dual_stack_ip_socket<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
id: &Self::DemuxSocketId<D, BT>,
) -> EitherStack<&TcpSocketId<Self, D, BT>, &TcpSocketId<Self::OtherVersion, D, BT>>
where
Self::OtherVersion: DualStackBaseIpExt;
/// Owned counterpart of [`DualStackBaseIpExt::as_dual_stack_ip_socket`].
fn into_dual_stack_ip_socket<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
id: Self::DemuxSocketId<D, BT>,
) -> EitherStack<TcpSocketId<Self, D, BT>, TcpSocketId<Self::OtherVersion, D, BT>>
where
Self::OtherVersion: DualStackBaseIpExt;
/// Wraps a socket ID of this IP version as a demux ID of this version.
fn into_demux_socket_id<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
id: TcpSocketId<Self, D, BT>,
) -> Self::DemuxSocketId<D, BT>
where
Self::OtherVersion: DualStackBaseIpExt;
/// Returns the local/remote addresses and device of a connection, expressed
/// in `Self`'s address family.
fn get_conn_info<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
) -> ConnectionInfo<Self::Addr, D>;
/// Returns a mutable reference to the connection's optional accept queue.
fn get_accept_queue_mut<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
conn_and_addr: &mut Self::ConnectionAndAddr<D, BT>,
) -> &mut Option<
AcceptQueue<
TcpSocketId<Self, D, BT>,
BT::ReturnedBuffers,
BT::ListenerNotifierOrProvidedBuffers,
>,
>
where
Self::OtherVersion: DualStackBaseIpExt;
/// Returns the connection's `defunct` flag.
fn get_defunct<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
) -> bool;
/// Borrows the connection's TCP state machine.
fn get_state<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
) -> &State<BT::Instant, BT::ReceiveBuffer, BT::SendBuffer, BT::ListenerNotifierOrProvidedBuffers>;
/// Converts a listener address into the [`BoundInfo`] reported for it.
fn get_bound_info<D: WeakDeviceIdentifier>(
listener_addr: &ListenerAddr<Self::ListenerIpAddr, D>,
) -> BoundInfo<Self::Addr, D>;
/// Destroys the socket referenced by `demux_id` by calling `destroy_socket`
/// with the socket ID of whichever IP version it belongs to.
fn destroy_socket_with_demux_id<
CC: TcpContext<Self, BC> + TcpContext<Self::OtherVersion, BC>,
BC: TcpBindingsContext,
>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
demux_id: Self::DemuxSocketId<CC::WeakDeviceId, BC>,
) where
Self::OtherVersion: DualStackBaseIpExt;
/// Converts an `OriginalDstAddr` into this version's address type.
fn get_original_dst(addr: Self::OriginalDstAddr) -> Self::Addr;
}
/// IPv4 never owns dual-stack connections itself. Its demux IDs, however,
/// may refer to IPv4 sockets or to IPv6 sockets operating on IPv4.
impl DualStackBaseIpExt for Ipv4 {
    type DemuxSocketId<D: WeakDeviceIdentifier, BT: TcpBindingsTypes> =
        EitherStack<TcpSocketId<Ipv4, D, BT>, TcpSocketId<Ipv6, D, BT>>;
    type ConnectionAndAddr<D: WeakDeviceIdentifier, BT: TcpBindingsTypes> =
        (Connection<Ipv4, Ipv4, D, BT>, ConnAddr<ConnIpAddr<Ipv4Addr, NonZeroU16, NonZeroU16>, D>);
    type ListenerIpAddr = ListenerIpAddr<Ipv4Addr, NonZeroU16>;
    type OriginalDstAddr = Ipv4Addr;
    type DualStackIpOptions = ();

    fn as_dual_stack_ip_socket<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        id: &Self::DemuxSocketId<D, BT>,
    ) -> EitherStack<&TcpSocketId<Self, D, BT>, &TcpSocketId<Self::OtherVersion, D, BT>> {
        // The demux ID already is an `EitherStack`; just re-wrap the borrow.
        match id {
            EitherStack::ThisStack(v4_socket) => EitherStack::ThisStack(v4_socket),
            EitherStack::OtherStack(v6_socket) => EitherStack::OtherStack(v6_socket),
        }
    }

    fn into_dual_stack_ip_socket<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        id: Self::DemuxSocketId<D, BT>,
    ) -> EitherStack<TcpSocketId<Self, D, BT>, TcpSocketId<Self::OtherVersion, D, BT>> {
        // Identical representation: no conversion needed.
        id
    }

    fn into_demux_socket_id<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        id: TcpSocketId<Self, D, BT>,
    ) -> Self::DemuxSocketId<D, BT> {
        EitherStack::ThisStack(id)
    }

    fn get_conn_info<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
    ) -> ConnectionInfo<Self::Addr, D> {
        let (_, conn_addr) = conn_and_addr;
        conn_addr.clone().into()
    }

    fn get_accept_queue_mut<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &mut Self::ConnectionAndAddr<D, BT>,
    ) -> &mut Option<
        AcceptQueue<
            TcpSocketId<Self, D, BT>,
            BT::ReturnedBuffers,
            BT::ListenerNotifierOrProvidedBuffers,
        >,
    > {
        &mut conn_and_addr.0.accept_queue
    }

    fn get_defunct<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
    ) -> bool {
        conn_and_addr.0.defunct
    }

    fn get_state<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
    ) -> &State<BT::Instant, BT::ReceiveBuffer, BT::SendBuffer, BT::ListenerNotifierOrProvidedBuffers>
    {
        &conn_and_addr.0.state
    }

    fn get_bound_info<D: WeakDeviceIdentifier>(
        listener_addr: &ListenerAddr<Self::ListenerIpAddr, D>,
    ) -> BoundInfo<Self::Addr, D> {
        let owned_addr = listener_addr.clone();
        owned_addr.into()
    }

    fn destroy_socket_with_demux_id<
        CC: TcpContext<Self, BC> + TcpContext<Self::OtherVersion, BC>,
        BC: TcpBindingsContext,
    >(
        core_ctx: &mut CC,
        bindings_ctx: &mut BC,
        demux_id: Self::DemuxSocketId<CC::WeakDeviceId, BC>,
    ) {
        // Dispatch to the IP version the socket actually belongs to.
        match demux_id {
            EitherStack::ThisStack(v4_socket) => destroy_socket(core_ctx, bindings_ctx, v4_socket),
            EitherStack::OtherStack(v6_socket) => {
                destroy_socket(core_ctx, bindings_ctx, v6_socket)
            }
        }
    }

    fn get_original_dst(dst: Self::OriginalDstAddr) -> Self::Addr {
        dst
    }
}
/// IPv6-specific socket options relevant to dual-stack operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ipv6Options {
    /// Whether the socket may also operate on IPv4. Enabled by default.
    pub dual_stack_enabled: bool,
}

impl Default for Ipv6Options {
    fn default() -> Self {
        Self { dual_stack_enabled: true }
    }
}
/// IPv6 sockets may be dual-stack: one IPv6 socket can hold either an IPv6
/// connection or an IPv4 connection (`ConnectionAndAddr` is an
/// `EitherStack`). IPv4 addresses are reported to callers as IPv4-mapped IPv6
/// addresses.
impl DualStackBaseIpExt for Ipv6 {
    type DemuxSocketId<D: WeakDeviceIdentifier, BT: TcpBindingsTypes> = TcpSocketId<Ipv6, D, BT>;
    type ConnectionAndAddr<D: WeakDeviceIdentifier, BT: TcpBindingsTypes> = EitherStack<
        (Connection<Ipv6, Ipv6, D, BT>, ConnAddr<ConnIpAddr<Ipv6Addr, NonZeroU16, NonZeroU16>, D>),
        (Connection<Ipv6, Ipv4, D, BT>, ConnAddr<ConnIpAddr<Ipv4Addr, NonZeroU16, NonZeroU16>, D>),
    >;
    type DualStackIpOptions = Ipv6Options;
    type ListenerIpAddr = DualStackListenerIpAddr<Ipv6Addr, NonZeroU16>;
    type OriginalDstAddr = EitherStack<Ipv6Addr, Ipv4Addr>;

    fn as_dual_stack_ip_socket<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        id: &Self::DemuxSocketId<D, BT>,
    ) -> EitherStack<&TcpSocketId<Self, D, BT>, &TcpSocketId<Self::OtherVersion, D, BT>> {
        // IPv6 demux IDs always refer to IPv6 sockets.
        EitherStack::ThisStack(id)
    }

    fn into_dual_stack_ip_socket<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        id: Self::DemuxSocketId<D, BT>,
    ) -> EitherStack<TcpSocketId<Self, D, BT>, TcpSocketId<Self::OtherVersion, D, BT>> {
        EitherStack::ThisStack(id)
    }

    fn into_demux_socket_id<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        id: TcpSocketId<Self, D, BT>,
    ) -> Self::DemuxSocketId<D, BT> {
        id
    }

    fn get_conn_info<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
    ) -> ConnectionInfo<Self::Addr, D> {
        match conn_and_addr {
            EitherStack::ThisStack((_conn, addr)) => addr.clone().into(),
            // IPv4 connection on a dual-stack socket: report both endpoints
            // as IPv4-mapped IPv6 addresses.
            EitherStack::OtherStack((
                _conn,
                ConnAddr {
                    ip:
                        ConnIpAddr { local: (local_ip, local_port), remote: (remote_ip, remote_port) },
                    device,
                },
            )) => ConnectionInfo {
                local_addr: SocketAddr {
                    ip: maybe_zoned(local_ip.addr().to_ipv6_mapped(), device),
                    port: *local_port,
                },
                remote_addr: SocketAddr {
                    ip: maybe_zoned(remote_ip.addr().to_ipv6_mapped(), device),
                    port: *remote_port,
                },
                device: device.clone(),
            },
        }
    }

    fn get_accept_queue_mut<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &mut Self::ConnectionAndAddr<D, BT>,
    ) -> &mut Option<
        AcceptQueue<
            TcpSocketId<Self, D, BT>,
            BT::ReturnedBuffers,
            BT::ListenerNotifierOrProvidedBuffers,
        >,
    > {
        match conn_and_addr {
            EitherStack::ThisStack((conn, _addr)) => &mut conn.accept_queue,
            EitherStack::OtherStack((conn, _addr)) => &mut conn.accept_queue,
        }
    }

    fn get_defunct<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
    ) -> bool {
        match conn_and_addr {
            EitherStack::ThisStack((conn, _addr)) => conn.defunct,
            EitherStack::OtherStack((conn, _addr)) => conn.defunct,
        }
    }

    fn get_state<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
        conn_and_addr: &Self::ConnectionAndAddr<D, BT>,
    ) -> &State<BT::Instant, BT::ReceiveBuffer, BT::SendBuffer, BT::ListenerNotifierOrProvidedBuffers>
    {
        match conn_and_addr {
            EitherStack::ThisStack((conn, _addr)) => &conn.state,
            EitherStack::OtherStack((conn, _addr)) => &conn.state,
        }
    }

    fn get_bound_info<D: WeakDeviceIdentifier>(
        ListenerAddr { ip, device }: &ListenerAddr<Self::ListenerIpAddr, D>,
    ) -> BoundInfo<Self::Addr, D> {
        match ip {
            DualStackListenerIpAddr::ThisStack(ip) => {
                ListenerAddr { ip: ip.clone(), device: device.clone() }.into()
            }
            // Bound on IPv4 only. A missing address is reported as the mapped
            // unspecified IPv4 address. This distinguishes it from
            // `BothStacks`, which reports `None`.
            DualStackListenerIpAddr::OtherStack(ListenerIpAddr {
                addr,
                identifier: local_port,
            }) => BoundInfo {
                addr: Some(maybe_zoned(
                    addr.map(|a| a.addr()).unwrap_or(Ipv4::UNSPECIFIED_ADDRESS).to_ipv6_mapped(),
                    // `device` is already `&Option<D>` (bound through the
                    // reference pattern), matching `get_conn_info`; no extra
                    // borrow is needed.
                    device,
                )),
                port: *local_port,
                device: device.clone(),
            },
            DualStackListenerIpAddr::BothStacks(local_port) => {
                BoundInfo { addr: None, port: *local_port, device: device.clone() }
            }
        }
    }

    fn destroy_socket_with_demux_id<
        CC: TcpContext<Self, BC> + TcpContext<Self::OtherVersion, BC>,
        BC: TcpBindingsContext,
    >(
        core_ctx: &mut CC,
        bindings_ctx: &mut BC,
        demux_id: Self::DemuxSocketId<CC::WeakDeviceId, BC>,
    ) {
        destroy_socket(core_ctx, bindings_ctx, demux_id)
    }

    fn get_original_dst(addr: Self::OriginalDstAddr) -> Self::Addr {
        match addr {
            EitherStack::ThisStack(addr) => addr,
            // IPv4 destinations are expressed as IPv4-mapped IPv6 addresses.
            EitherStack::OtherStack(addr) => *addr.to_ipv6_mapped(),
        }
    }
}
/// Timer identifier holding a weak TCP socket ID of either IP version.
///
/// Presumably used to route timer expirations back to the owning socket —
/// confirm against the `HandleableTimer` impl elsewhere in the file.
#[derive(Derivative, GenericOverIp)]
#[generic_over_ip()]
#[derivative(
Clone(bound = ""),
Eq(bound = ""),
PartialEq(bound = ""),
Hash(bound = ""),
Debug(bound = "")
)]
#[allow(missing_docs)]
pub enum TcpTimerId<D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
V4(WeakTcpSocketId<Ipv4, D, BT>),
V6(WeakTcpSocketId<Ipv6, D, BT>),
}
/// Wraps a weak socket ID in the [`TcpTimerId`] variant matching its IP
/// version.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
    From<WeakTcpSocketId<I, D, BT>> for TcpTimerId<D, BT>
{
    fn from(id: WeakTcpSocketId<I, D, BT>) -> Self {
        I::map_ip(id, |v4_id| TcpTimerId::V4(v4_id), |v6_id| TcpTimerId::V6(v6_id))
    }
}
/// Types and buffer constructors that the bindings provide to TCP.
pub trait TcpBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
/// Receive buffer type used by connections.
type ReceiveBuffer: ReceiveBuffer + Send + Sync;
/// Send buffer type used by connections.
type SendBuffer: SendBuffer + Send + Sync;
/// Extra value returned alongside buffers from `new_passive_open_buffers`.
type ReturnedBuffers: Debug + Send + Sync;
/// Value that is both a listener notifier and convertible into a
/// receive/send buffer pair.
type ListenerNotifierOrProvidedBuffers: Debug
+ IntoBuffers<Self::ReceiveBuffer, Self::SendBuffer>
+ ListenerNotifier
+ Send
+ Sync;
/// Default buffer sizes.
fn default_buffer_sizes() -> BufferSizes;
/// Creates buffers sized by `buffer_sizes`, presumably for a connection
/// accepted through passive open (per the name).
fn new_passive_open_buffers(
buffer_sizes: BufferSizes,
) -> (Self::ReceiveBuffer, Self::SendBuffer, Self::ReturnedBuffers);
}
/// Full set of bindings capabilities required by TCP. It is an alias-like
/// trait combining the supertraits below.
pub trait TcpBindingsContext:
Sized
+ DeferredResourceRemovalContext
+ TimerContext
+ TracingContext
+ RngContext
+ TcpBindingsTypes
{
}
/// Blanket implementation: any type with all required capabilities is a
/// [`TcpBindingsContext`].
impl<BC> TcpBindingsContext for BC where
BC: Sized
+ DeferredResourceRemovalContext
+ TimerContext
+ TracingContext
+ RngContext
+ TcpBindingsTypes
{
}
/// Core context giving access to the TCP demultiplexing state for IP
/// version `I`.
pub trait TcpDemuxContext<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>:
TcpCoreTimerContext<I, D, BT>
{
/// IP transport context type associated with this demux context.
type IpTransportCtx<'a>: TransportIpContext<I, BT, DeviceId = D::Strong, WeakDeviceId = D>
+ DeviceIpSocketHandler<I, BT>
+ TcpCoreTimerContext<I, D, BT>
+ CounterContext<TcpCounters<I>>
+ CounterContext<TcpCounters<I::OtherVersion>>;
/// Calls `cb` with shared access to the demux state.
fn with_demux<O, F: FnOnce(&DemuxState<I, D, BT>) -> O>(&mut self, cb: F) -> O;
/// Calls `cb` with exclusive access to the demux state.
fn with_demux_mut<O, F: FnOnce(&mut DemuxState<I, D, BT>) -> O>(&mut self, cb: F) -> O;
}
/// Views a (possibly dual-stack) context as its this-stack-only context `T`.
pub trait AsThisStack<T> {
/// Returns the context as `&mut T`.
fn as_this_stack(&mut self) -> &mut T;
}
/// Every type trivially views itself as its own this-stack context.
impl<T> AsThisStack<T> for T {
fn as_this_stack(&mut self) -> &mut T {
self
}
}
/// Core timer context whose timers are identified by weak TCP socket IDs.
pub trait TcpCoreTimerContext<I: DualStackIpExt, D: WeakDeviceIdentifier, BC: TcpBindingsTypes>:
CoreTimerContext<WeakTcpSocketId<I, D, BC>, BC>
{
}
/// Blanket implementation for every context with a matching
/// [`CoreTimerContext`].
impl<CC, I, D, BC> TcpCoreTimerContext<I, D, BC> for CC
where
I: DualStackIpExt,
D: WeakDeviceIdentifier,
BC: TcpBindingsTypes,
CC: CoreTimerContext<WeakTcpSocketId<I, D, BC>, BC>,
{
}
/// Converts between `I`'s associated socket types and their explicit
/// dual-stack forms (`EitherStack` / `DualStackListenerIpAddr`), for IP
/// versions that support dual-stack operation.
pub trait DualStackConverter<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>:
OwnedOrRefsBidirectionalConverter<
I::ConnectionAndAddr<D, BT>,
EitherStack<
(
Connection<I, I, D, BT>,
ConnAddr<ConnIpAddr<<I as Ip>::Addr, NonZeroU16, NonZeroU16>, D>,
),
(
Connection<I, I::OtherVersion, D, BT>,
ConnAddr<ConnIpAddr<<I::OtherVersion as Ip>::Addr, NonZeroU16, NonZeroU16>, D>,
),
>,
> + OwnedOrRefsBidirectionalConverter<
I::ListenerIpAddr,
DualStackListenerIpAddr<I::Addr, NonZeroU16>,
> + OwnedOrRefsBidirectionalConverter<
ListenerAddr<I::ListenerIpAddr, D>,
ListenerAddr<DualStackListenerIpAddr<I::Addr, NonZeroU16>, D>,
> + OwnedOrRefsBidirectionalConverter<
I::OriginalDstAddr,
EitherStack<I::Addr, <I::OtherVersion as Ip>::Addr>,
>
{
}
/// Blanket implementation for any type providing all four conversions.
impl<I, D, BT, O> DualStackConverter<I, D, BT> for O
where
I: DualStackIpExt,
D: WeakDeviceIdentifier,
BT: TcpBindingsTypes,
O: OwnedOrRefsBidirectionalConverter<
I::ConnectionAndAddr<D, BT>,
EitherStack<
(
Connection<I, I, D, BT>,
ConnAddr<ConnIpAddr<<I as Ip>::Addr, NonZeroU16, NonZeroU16>, D>,
),
(
Connection<I, I::OtherVersion, D, BT>,
ConnAddr<ConnIpAddr<<I::OtherVersion as Ip>::Addr, NonZeroU16, NonZeroU16>, D>,
),
>,
> + OwnedOrRefsBidirectionalConverter<
I::ListenerIpAddr,
DualStackListenerIpAddr<I::Addr, NonZeroU16>,
> + OwnedOrRefsBidirectionalConverter<
ListenerAddr<I::ListenerIpAddr, D>,
ListenerAddr<DualStackListenerIpAddr<I::Addr, NonZeroU16>, D>,
> + OwnedOrRefsBidirectionalConverter<
I::OriginalDstAddr,
EitherStack<I::Addr, <I::OtherVersion as Ip>::Addr>,
>,
{
}
/// Converts between `I`'s associated socket types and their single-stack
/// forms, for IP versions (or contexts) without dual-stack support.
pub trait SingleStackConverter<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>:
OwnedOrRefsBidirectionalConverter<
I::ConnectionAndAddr<D, BT>,
(Connection<I, I, D, BT>, ConnAddr<ConnIpAddr<<I as Ip>::Addr, NonZeroU16, NonZeroU16>, D>),
> + OwnedOrRefsBidirectionalConverter<I::ListenerIpAddr, ListenerIpAddr<I::Addr, NonZeroU16>>
+ OwnedOrRefsBidirectionalConverter<
ListenerAddr<I::ListenerIpAddr, D>,
ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>,
> + OwnedOrRefsBidirectionalConverter<I::OriginalDstAddr, I::Addr>
{
}
/// Blanket implementation for any type providing all four conversions.
impl<I, D, BT, O> SingleStackConverter<I, D, BT> for O
where
I: DualStackIpExt,
D: WeakDeviceIdentifier,
BT: TcpBindingsTypes,
O: OwnedOrRefsBidirectionalConverter<
I::ConnectionAndAddr<D, BT>,
(
Connection<I, I, D, BT>,
ConnAddr<ConnIpAddr<<I as Ip>::Addr, NonZeroU16, NonZeroU16>, D>,
),
> + OwnedOrRefsBidirectionalConverter<I::ListenerIpAddr, ListenerIpAddr<I::Addr, NonZeroU16>>
+ OwnedOrRefsBidirectionalConverter<
ListenerAddr<I::ListenerIpAddr, D>,
ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>,
> + OwnedOrRefsBidirectionalConverter<I::OriginalDstAddr, I::Addr>,
{
}
/// Core context for TCP sockets of IP version `I`.
///
/// Gives access to the socket set and per-socket state. It also supplies
/// either a dual-stack or a single-stack transport/demux context, together
/// with the matching converter.
pub trait TcpContext<I: DualStackIpExt, BC: TcpBindingsTypes>:
TcpDemuxContext<I, Self::WeakDeviceId, BC> + IpSocketHandler<I, BC>
{
/// Transport + demux context restricted to `I`.
type ThisStackIpTransportAndDemuxCtx<'a>: TransportIpContext<I, BC, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
+ DeviceIpSocketHandler<I, BC>
+ TcpDemuxContext<I, Self::WeakDeviceId, BC>
+ CounterContext<TcpCounters<I>>;
/// Context handed to callbacks for sockets operating single-stack.
type SingleStackIpTransportAndDemuxCtx<'a>: TransportIpContext<I, BC, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
+ DeviceIpSocketHandler<I, BC>
+ TcpDemuxContext<I, Self::WeakDeviceId, BC>
+ AsThisStack<Self::ThisStackIpTransportAndDemuxCtx<'a>>
+ CounterContext<TcpCounters<I>>;
/// Converter paired with `SingleStackIpTransportAndDemuxCtx`.
type SingleStackConverter: SingleStackConverter<I, Self::WeakDeviceId, BC>;
/// Context handed to callbacks for dual-stack sockets. It provides transport
/// and demux access for both `I` and `I::OtherVersion`.
type DualStackIpTransportAndDemuxCtx<'a>: TransportIpContext<I, BC, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
+ DeviceIpSocketHandler<I, BC>
+ TcpDemuxContext<I, Self::WeakDeviceId, BC>
+ TransportIpContext<
I::OtherVersion,
BC,
DeviceId = Self::DeviceId,
WeakDeviceId = Self::WeakDeviceId,
> + DeviceIpSocketHandler<I::OtherVersion, BC>
+ TcpDemuxContext<I::OtherVersion, Self::WeakDeviceId, BC>
+ TcpDualStackContext<I, Self::WeakDeviceId, BC>
+ AsThisStack<Self::ThisStackIpTransportAndDemuxCtx<'a>>
+ CounterContext<TcpCounters<I>>
+ CounterContext<TcpCounters<I::OtherVersion>>;
/// Converter paired with `DualStackIpTransportAndDemuxCtx`.
type DualStackConverter: DualStackConverter<I, Self::WeakDeviceId, BC>;
/// Calls `cb` with exclusive access to the set of all TCP sockets.
fn with_all_sockets_mut<O, F: FnOnce(&mut TcpSocketSet<I, Self::WeakDeviceId, BC>) -> O>(
&mut self,
cb: F,
) -> O;
/// Calls `cb` once per socket with its ID and state.
fn for_each_socket<
F: FnMut(&TcpSocketId<I, Self::WeakDeviceId, BC>, &TcpSocketState<I, Self::WeakDeviceId, BC>),
>(
&mut self,
cb: F,
);
/// Calls `cb` with:
/// - a dual-stack or single-stack transport/demux context plus its
///   converter (which variant is chosen is implementation-defined);
/// - mutable access to the state of socket `id`;
/// - the ISN generator.
///
/// All other `with_socket*` helpers below are built on this method or on
/// `with_socket_and_converter`.
fn with_socket_mut_isn_transport_demux<
O,
F: for<'a> FnOnce(
MaybeDualStack<
(&'a mut Self::DualStackIpTransportAndDemuxCtx<'a>, Self::DualStackConverter),
(&'a mut Self::SingleStackIpTransportAndDemuxCtx<'a>, Self::SingleStackConverter),
>,
&mut TcpSocketState<I, Self::WeakDeviceId, BC>,
&IsnGenerator<BC::Instant>,
) -> O,
>(
&mut self,
id: &TcpSocketId<I, Self::WeakDeviceId, BC>,
cb: F,
) -> O;
/// Calls `cb` with shared access to the state of socket `id`, discarding
/// the converter.
fn with_socket<O, F: FnOnce(&TcpSocketState<I, Self::WeakDeviceId, BC>) -> O>(
&mut self,
id: &TcpSocketId<I, Self::WeakDeviceId, BC>,
cb: F,
) -> O {
self.with_socket_and_converter(id, |socket_state, _converter| cb(socket_state))
}
/// Calls `cb` with shared access to the state of socket `id` and the
/// dual-stack or single-stack converter.
fn with_socket_and_converter<
O,
F: FnOnce(
&TcpSocketState<I, Self::WeakDeviceId, BC>,
MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter>,
) -> O,
>(
&mut self,
id: &TcpSocketId<I, Self::WeakDeviceId, BC>,
cb: F,
) -> O;
/// Like `with_socket_mut_isn_transport_demux`, but without the ISN
/// generator.
fn with_socket_mut_transport_demux<
O,
F: for<'a> FnOnce(
MaybeDualStack<
(&'a mut Self::DualStackIpTransportAndDemuxCtx<'a>, Self::DualStackConverter),
(&'a mut Self::SingleStackIpTransportAndDemuxCtx<'a>, Self::SingleStackConverter),
>,
&mut TcpSocketState<I, Self::WeakDeviceId, BC>,
) -> O,
>(
&mut self,
id: &TcpSocketId<I, Self::WeakDeviceId, BC>,
cb: F,
) -> O {
self.with_socket_mut_isn_transport_demux(id, |ctx, socket_state, _isn| {
cb(ctx, socket_state)
})
}
/// Calls `cb` with mutable access to the state of socket `id` only.
fn with_socket_mut<O, F: FnOnce(&mut TcpSocketState<I, Self::WeakDeviceId, BC>) -> O>(
&mut self,
id: &TcpSocketId<I, Self::WeakDeviceId, BC>,
cb: F,
) -> O {
self.with_socket_mut_isn_transport_demux(id, |_ctx, socket_state, _isn| cb(socket_state))
}
/// Calls `cb` with mutable access to the state of socket `id` and the
/// converter. The core context is dropped, keeping the dual-stack/single-stack
/// variant.
fn with_socket_mut_and_converter<
O,
F: FnOnce(
&mut TcpSocketState<I, Self::WeakDeviceId, BC>,
MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter>,
) -> O,
>(
&mut self,
id: &TcpSocketId<I, Self::WeakDeviceId, BC>,
cb: F,
) -> O {
self.with_socket_mut_isn_transport_demux(id, |ctx, socket_state, _isn| {
// Keep only the converter, preserving which stack variant it is.
let converter = match ctx {
MaybeDualStack::NotDualStack((_core_ctx, converter)) => {
MaybeDualStack::NotDualStack(converter)
}
MaybeDualStack::DualStack((_core_ctx, converter)) => {
MaybeDualStack::DualStack(converter)
}
};
cb(socket_state, converter)
})
}
}
/// Zero-sized converter that turns IPv6 socket IDs into IPv4 demux IDs.
#[derive(Clone, Copy)]
pub struct Ipv6SocketIdToIpv4DemuxIdConverter;
/// Converts a socket ID of version `I` into a demux ID of `I::OtherVersion`.
/// `TcpDualStackContext::into_other_demux_socket_id` uses it to build the
/// other-stack half of a dual-stack demux ID.
pub trait DualStackDemuxIdConverter<I: DualStackIpExt>: 'static + Clone + Copy {
/// Wraps `id` as a demux ID of the other IP version.
fn convert<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
&self,
id: TcpSocketId<I, D, BT>,
) -> <I::OtherVersion as DualStackBaseIpExt>::DemuxSocketId<D, BT>;
}
impl DualStackDemuxIdConverter<Ipv6> for Ipv6SocketIdToIpv4DemuxIdConverter {
fn convert<D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
&self,
id: TcpSocketId<Ipv6, D, BT>,
) -> <Ipv4 as DualStackBaseIpExt>::DemuxSocketId<D, BT> {
// An IPv6 socket appears in the IPv4 demux map as an `OtherStack` entry.
EitherStack::OtherStack(id)
}
}
/// Operations available only when IP version `I` supports dual-stack
/// sockets.
pub trait TcpDualStackContext<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Transport context covering both `I` and `I::OtherVersion`.
type DualStackIpTransportCtx<'a>: TransportIpContext<I, BT, DeviceId = D::Strong, WeakDeviceId = D>
+ DeviceIpSocketHandler<I, BT>
+ TcpCoreTimerContext<I, D, BT>
+ TransportIpContext<I::OtherVersion, BT, DeviceId = D::Strong, WeakDeviceId = D>
+ DeviceIpSocketHandler<I::OtherVersion, BT>
+ TcpCoreTimerContext<I::OtherVersion, D, BT>
+ CounterContext<TcpCounters<I>>;
/// Returns the converter used to build other-stack demux IDs.
fn other_demux_id_converter(&self) -> impl DualStackDemuxIdConverter<I>;
/// Wraps `id` as a demux ID of `I::OtherVersion`.
fn into_other_demux_socket_id(
&self,
id: TcpSocketId<I, D, BT>,
) -> <I::OtherVersion as DualStackBaseIpExt>::DemuxSocketId<D, BT> {
self.other_demux_id_converter().convert(id)
}
/// Returns the demux IDs for `id` in both this stack and the other stack.
fn dual_stack_demux_id(
&self,
id: TcpSocketId<I, D, BT>,
) -> DualStackTuple<I, DemuxSocketId<I, D, BT>> {
let this_id = DemuxSocketId::<I, _, _>(I::into_demux_socket_id(id.clone()));
let other_id = DemuxSocketId::<I::OtherVersion, _, _>(self.into_other_demux_socket_id(id));
DualStackTuple::new(this_id, other_id)
}
/// Reads the dual-stack-enabled flag from `ip_options`.
fn dual_stack_enabled(&self, ip_options: &I::DualStackIpOptions) -> bool;
/// Sets the dual-stack-enabled flag in `ip_options` to `value`.
fn set_dual_stack_enabled(&self, ip_options: &mut I::DualStackIpOptions, value: bool);
/// Calls `cb` with exclusive access to the demux state of both IP versions.
fn with_both_demux_mut<
O,
F: FnOnce(&mut DemuxState<I, D, BT>, &mut DemuxState<I::OtherVersion, D, BT>) -> O,
>(
&mut self,
cb: F,
) -> O;
}
/// A TCP socket address: a specified, optionally zoned IP address plus a
/// non-zero port.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, GenericOverIp)]
#[generic_over_ip(A, IpAddress)]
pub struct SocketAddr<A: IpAddress, D> {
/// The IP address, with zone type `D`.
pub ip: ZonedAddr<SpecifiedAddr<A>, D>,
/// The port.
pub port: NonZeroU16,
}
/// Erases the IP version of a [`SocketAddr`] into an [`IpAddr`] holding the
/// matching IPv4 or IPv6 variant.
impl<A: IpAddress, D> From<SocketAddr<A, D>>
    for IpAddr<SocketAddr<Ipv4Addr, D>, SocketAddr<Ipv6Addr, D>>
{
    fn from(addr: SocketAddr<A, D>) -> IpAddr<SocketAddr<Ipv4Addr, D>, SocketAddr<Ipv6Addr, D>> {
        <A::Version as Ip>::map_ip_in(addr, IpAddr::V4, IpAddr::V6)
    }
}
impl<A: IpAddress, D> SocketAddr<A, D> {
    /// Returns this socket address with the IP address's zone mapped through
    /// `f`. The port is carried over unchanged.
    pub fn map_zone<Y>(self, f: impl FnOnce(D) -> Y) -> SocketAddr<A, Y> {
        SocketAddr { port: self.port, ip: self.ip.map_zone(f) }
    }
}
/// Formats the address and port via `AddrAndPortFormatter` for the address's
/// IP version.
impl<A: IpAddress, D: fmt::Display> fmt::Display for SocketAddr<A, D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        // Borrow the zoned address down to `ZonedAddr<&A, &D>` for the
        // formatter.
        let borrowed_ip = self.ip.as_ref().map_addr(core::convert::AsRef::<A>::as_ref);
        AddrAndPortFormatter::<_, _, A::Version>::new(borrowed_ip, &self.port).fmt(f)
    }
}
/// Socket-map address spec for TCP: local and remote identifiers are both
/// non-zero ports.
pub(crate) enum TcpPortSpec {}
impl SocketMapAddrSpec for TcpPortSpec {
type RemoteIdentifier = NonZeroU16;
type LocalIdentifier = NonZeroU16;
}
/// Uninhabited marker type for the TCP transport layer. Its trait impls are
/// not visible in this part of the file.
pub enum TcpIpTransportContext {}
/// Bounds required of IDs stored in the TCP socket map.
pub trait SpecSocketId: Clone + Eq + PartialEq + Debug + 'static {}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> SpecSocketId
for TcpSocketId<I, D, BT>
{
}
/// An ID belonging to either of two stacks is itself a valid socket-map ID.
impl<A: SpecSocketId, B: SpecSocketId> SpecSocketId for EitherStack<A, B> {}
/// Type-level specification of the TCP socket map. It can never be
/// constructed, because it contains [`Never`].
struct TcpSocketSpec<I, D, BT>(PhantomData<(I, D, BT)>, Never);
/// Socket-map state specification for TCP. Listener and connection entries
/// both store this IP version's demux socket ID.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> SocketMapStateSpec
    for TcpSocketSpec<I, D, BT>
{
    type ListenerId = I::DemuxSocketId<D, BT>;
    type ConnId = I::DemuxSocketId<D, BT>;
    type ListenerSharingState = ListenerSharingState;
    type ConnSharingState = SharingState;
    type AddrVecTag = AddrVecTag;
    type ListenerAddrState = ListenerAddrState<Self::ListenerId>;
    type ConnAddrState = ConnAddrState<Self::ConnId>;

    /// Summarizes a listener-address entry as an [`AddrVecTag`].
    ///
    /// A shared entry is tagged as a listener only while it holds one;
    /// otherwise it is tagged as bound.
    fn listener_tag(
        ListenerAddrInfo { has_device, specified_addr: _ }: ListenerAddrInfo,
        state: &Self::ListenerAddrState,
    ) -> Self::AddrVecTag {
        let (sharing, state) = match state {
            ListenerAddrState::ExclusiveBound(_) => {
                (SharingState::Exclusive, SocketTagState::Bound)
            }
            ListenerAddrState::ExclusiveListener(_) => {
                (SharingState::Exclusive, SocketTagState::Listener)
            }
            ListenerAddrState::Shared { listener: Some(_), bound: _ } => {
                (SharingState::ReuseAddress, SocketTagState::Listener)
            }
            ListenerAddrState::Shared { listener: None, bound: _ } => {
                (SharingState::ReuseAddress, SocketTagState::Bound)
            }
        };
        AddrVecTag { sharing, state, has_device }
    }

    /// Summarizes a connection entry: always tagged `Conn`, carrying the
    /// entry's sharing state.
    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
        AddrVecTag { state: SocketTagState::Conn, sharing: state.sharing, has_device }
    }
}
/// Tag summarizing a socket-map entry. Conflict checks inspect these tags
/// through the map's descendant counts.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct AddrVecTag {
/// Sharing mode of the entry.
sharing: SharingState,
/// Kind of socket(s) at the entry.
state: SocketTagState,
/// Whether the entry's address includes a device.
has_device: bool,
}
/// Kind of socket recorded in an [`AddrVecTag`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SocketTagState {
/// A connected socket.
Conn,
/// A listening socket (or a shared entry containing one).
Listener,
/// A bound, non-listening socket (or a shared entry without a listener).
Bound,
}
/// Sockets stored at one listener address.
#[derive(Debug)]
enum ListenerAddrState<S> {
/// A single exclusive, non-listening socket.
ExclusiveBound(S),
/// A single exclusive, listening socket.
ExclusiveListener(S),
/// `ReuseAddress` sockets: at most one listener plus any number of bound
/// sockets.
Shared { listener: Option<S>, bound: SmallVec<[S; 1]> },
}
/// Sharing mode plus listening flag of a socket at a listener address.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ListenerSharingState {
pub(crate) sharing: SharingState,
pub(crate) listening: bool,
}
/// Slot in a shared listener-address entry that a new socket ID goes into:
/// either the single listener slot or the list of bound sockets.
enum ListenerAddrInserter<'a, S> {
    Listener(&'a mut Option<S>),
    Bound(&'a mut SmallVec<[S; 1]>),
}

impl<S> Inserter<S> for ListenerAddrInserter<'_, S> {
    /// Stores `id` in the slot this inserter refers to.
    fn insert(self, id: S) {
        match self {
            Self::Bound(bound_ids) => bound_ids.push(id),
            Self::Listener(slot) => {
                *slot = Some(id);
            }
        }
    }
}
/// State of a socket that has an entry in the socket map.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
pub enum BoundSocketState<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Bound or listening: listener state, sharing state, and local address.
Listener((MaybeListener<I, D, BT>, ListenerSharingState, ListenerAddr<I::ListenerIpAddr, D>)),
/// Connected: connection plus address, sharing state, and the socket's timer.
Connected { conn: I::ConnectionAndAddr<D, BT>, sharing: SharingState, timer: BT::Timer },
}
impl<S: SpecSocketId> SocketMapAddrStateSpec for ListenerAddrState<S> {
type SharingState = ListenerSharingState;
type Id = S;
type Inserter<'a> = ListenerAddrInserter<'a, S>;
fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self {
let ListenerSharingState { sharing, listening } = new_sharing_state;
match sharing {
SharingState::Exclusive => match listening {
true => Self::ExclusiveListener(id),
false => Self::ExclusiveBound(id),
},
SharingState::ReuseAddress => {
let (listener, bound) =
if *listening { (Some(id), Default::default()) } else { (None, smallvec![id]) };
Self::Shared { listener, bound }
}
}
}
fn contains_id(&self, id: &Self::Id) -> bool {
match self {
Self::ExclusiveBound(x) | Self::ExclusiveListener(x) => id == x,
Self::Shared { listener, bound } => {
listener.as_ref().is_some_and(|x| id == x) || bound.contains(id)
}
}
}
fn could_insert(
&self,
new_sharing_state: &Self::SharingState,
) -> Result<(), IncompatibleError> {
match self {
Self::ExclusiveBound(_) | Self::ExclusiveListener(_) => Err(IncompatibleError),
Self::Shared { listener, bound: _ } => {
let ListenerSharingState { listening: _, sharing } = new_sharing_state;
match sharing {
SharingState::Exclusive => Err(IncompatibleError),
SharingState::ReuseAddress => match listener {
Some(_) => Err(IncompatibleError),
None => Ok(()),
},
}
}
}
}
fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult {
match self {
Self::ExclusiveBound(b) => {
assert_eq!(*b, id);
RemoveResult::IsLast
}
Self::ExclusiveListener(l) => {
assert_eq!(*l, id);
RemoveResult::IsLast
}
Self::Shared { listener, bound } => {
match listener {
Some(l) if *l == id => {
*listener = None;
}
Some(_) | None => {
let index = bound.iter().position(|b| *b == id).expect("invalid socket ID");
let _: S = bound.swap_remove(index);
}
};
match (listener, bound.is_empty()) {
(Some(_), _) => RemoveResult::Success,
(None, false) => RemoveResult::Success,
(None, true) => RemoveResult::IsLast,
}
}
}
}
fn try_get_inserter<'a, 'b>(
&'b mut self,
new_sharing_state: &'a Self::SharingState,
) -> Result<Self::Inserter<'b>, IncompatibleError> {
match self {
Self::ExclusiveBound(_) | Self::ExclusiveListener(_) => Err(IncompatibleError),
Self::Shared { listener, bound } => {
let ListenerSharingState { listening, sharing } = new_sharing_state;
match sharing {
SharingState::Exclusive => Err(IncompatibleError),
SharingState::ReuseAddress => {
match listener {
Some(_) => {
Err(IncompatibleError)
}
None => Ok(match listening {
true => ListenerAddrInserter::Listener(listener),
false => ListenerAddrInserter::Bound(bound),
}),
}
}
}
}
}
}
}
/// Decides whether a socket at a listener address may change its sharing or
/// listening state. Conflicts are judged against other entries in the socket
/// map. Conflicts with sockets at the exact same address are left to
/// `ListenerAddrState::try_update_sharing`.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
SocketMapUpdateSharingPolicy<
ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>,
ListenerSharingState,
I,
D,
TcpPortSpec,
> for TcpSocketSpec<I, D, BT>
{
/// Returns `Err(UpdateSharingError)` if moving the socket at `addr` from the
/// old to the new [`ListenerSharingState`] would conflict with other
/// entries in `socketmap`.
fn allows_sharing_update(
socketmap: &SocketMap<AddrVec<I, D, TcpPortSpec>, Bound<Self>>,
addr: &ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>,
ListenerSharingState{listening: old_listening, sharing: old_sharing}: &ListenerSharingState,
ListenerSharingState{listening: new_listening, sharing: new_sharing}: &ListenerSharingState,
) -> Result<(), UpdateSharingError> {
let ListenerAddr { device, ip } = addr;
// Arms, in order:
// - listener -> bound: always allowed;
// - no change in listening: nothing to check;
// - bound -> listener: must not conflict with other listeners.
match (old_listening, new_listening) {
(true, false) => (), (true, true) | (false, false) => (), (false, true) => {
// (1) Reject if any address shadowed by `addr` holds an exclusive
// socket or a listener.
let addr = AddrVec::Listen(addr.clone());
for a in addr.iter_shadows() {
if let Some(s) = socketmap.get(&a) {
match s {
Bound::Conn(c) => {
unreachable!("found conn state {c:?} at listener addr {a:?}")
}
Bound::Listen(l) => match l {
ListenerAddrState::ExclusiveListener(_)
| ListenerAddrState::ExclusiveBound(_) => {
return Err(UpdateSharingError);
}
ListenerAddrState::Shared { listener, bound: _ } => {
match listener {
Some(_) => {
return Err(UpdateSharingError);
}
None => (),
}
}
},
}
}
}
// (2) Reject if any listener is counted among the descendants of the
// device-less form of this address.
if socketmap.descendant_counts(&ListenerAddr { device: None, ip: *ip }.into()).any(
|(AddrVecTag { state, has_device: _, sharing: _ }, _): &(_, NonZeroUsize)| {
match state {
SocketTagState::Conn | SocketTagState::Bound => false,
SocketTagState::Listener => true,
}
},
) {
return Err(UpdateSharingError);
}
}
}
// Only ReuseAddress -> Exclusive needs checking; the other transitions
// never introduce new conflicts.
match (old_sharing, new_sharing) {
(SharingState::Exclusive, SharingState::Exclusive)
| (SharingState::ReuseAddress, SharingState::ReuseAddress)
| (SharingState::Exclusive, SharingState::ReuseAddress) => (),
(SharingState::ReuseAddress, SharingState::Exclusive) => {
// Address with no device and no IP, keeping only the port.
let root_addr = ListenerAddr {
device: None,
ip: ListenerIpAddr { addr: None, identifier: ip.identifier },
};
let conflicts = match device {
// No device: conflict with any bound/listening descendant of `addr`,
// or with any entry at the port-only root address (unless `addr` is
// the root itself).
None => {
socketmap.descendant_counts(&addr.clone().into()).any(
|(AddrVecTag { has_device: _, sharing: _, state }, _)| match state {
SocketTagState::Conn => false,
SocketTagState::Bound | SocketTagState::Listener => true,
},
) || (addr != &root_addr && socketmap.get(&root_addr.into()).is_some())
}
// With a device: conflict with any device-less bound/listening
// socket on the port, or any bound/listening descendant of `addr`.
Some(_) => {
socketmap.descendant_counts(&root_addr.into()).any(
|(AddrVecTag { has_device, sharing: _, state }, _)| match state {
SocketTagState::Conn => false,
SocketTagState::Bound | SocketTagState::Listener => !has_device,
},
)
|| socketmap.descendant_counts(&addr.clone().into()).any(
|(AddrVecTag { has_device: _, sharing: _, state }, _)| match state {
SocketTagState::Conn => false,
SocketTagState::Bound | SocketTagState::Listener => true,
},
)
}
};
if conflicts {
return Err(UpdateSharingError);
}
}
}
Ok(())
}
}
/// In-place sharing/listening updates for a socket already stored at a
/// listener address.
impl<S: SpecSocketId> SocketMapAddrStateUpdateSharingSpec for ListenerAddrState<S> {
    /// Moves `id`, which must already be stored here, to the requested
    /// sharing and listening state.
    ///
    /// Returns [`IncompatibleError`] when the new state conflicts with the
    /// other sockets at this address. Two cases are rejected:
    /// - a second listener;
    /// - an exclusive socket while any other socket (bound or listening) is
    ///   present.
    ///
    /// On error `self` is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not stored at this address.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        ListenerSharingState { listening: new_listening, sharing: new_sharing }: &Self::SharingState,
    ) -> Result<(), IncompatibleError> {
        match self {
            // `id` is the only socket here, so any new state is allowed.
            Self::ExclusiveBound(i) | Self::ExclusiveListener(i) => {
                assert_eq!(i, &id);
                *self = match new_sharing {
                    SharingState::Exclusive => match new_listening {
                        true => Self::ExclusiveListener(id),
                        false => Self::ExclusiveBound(id),
                    },
                    SharingState::ReuseAddress => {
                        let (listener, bound) = match new_listening {
                            true => (Some(id), Default::default()),
                            false => (None, smallvec![id]),
                        };
                        Self::Shared { listener, bound }
                    }
                };
                Ok(())
            }
            Self::Shared { listener, bound } => {
                if listener.as_ref() == Some(&id) {
                    // `id` is the current listener.
                    match new_sharing {
                        SharingState::Exclusive => {
                            // Exclusive only if no bound sockets share the
                            // address.
                            if bound.is_empty() {
                                *self = match new_listening {
                                    true => Self::ExclusiveListener(id),
                                    false => Self::ExclusiveBound(id),
                                };
                                Ok(())
                            } else {
                                Err(IncompatibleError)
                            }
                        }
                        SharingState::ReuseAddress => match new_listening {
                            // Still the shared listener: nothing changes.
                            true => Ok(()),
                            // Demote the listener to a bound socket.
                            false => {
                                bound.push(id);
                                *listener = None;
                                Ok(())
                            }
                        },
                    }
                } else {
                    // `id` is one of the bound sockets.
                    let index = bound
                        .iter()
                        .position(|b| b == &id)
                        .expect("ID is neither listener nor bound");
                    // At most one listener may exist at a shared address.
                    if *new_listening && listener.is_some() {
                        return Err(IncompatibleError);
                    }
                    match new_sharing {
                        SharingState::Exclusive => {
                            // Exclusive only if `id` is the sole socket here.
                            // The listener must be checked as well: otherwise
                            // overwriting `*self` would silently drop it from
                            // the map.
                            if bound.len() > 1 || listener.is_some() {
                                Err(IncompatibleError)
                            } else {
                                *self = match new_listening {
                                    true => Self::ExclusiveListener(id),
                                    false => Self::ExclusiveBound(id),
                                };
                                Ok(())
                            }
                        }
                        SharingState::ReuseAddress => match new_listening {
                            // Still a shared bound socket: nothing changes.
                            false => Ok(()),
                            // Promote to the (absent) listener slot.
                            true => {
                                let _: S = bound.swap_remove(index);
                                *listener = Some(id);
                                Ok(())
                            }
                        },
                    }
                }
            }
        }
    }
}
/// Whether a socket's local address may be shared with other sockets.
///
/// Defaults to [`SharingState::Exclusive`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SharingState {
    /// The address may not be shared.
    #[default]
    Exclusive,
    /// The address may be shared with other sockets that also use
    /// `ReuseAddress`.
    ReuseAddress,
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
    SocketMapConflictPolicy<
        ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>,
        ListenerSharingState,
        I,
        D,
        TcpPortSpec,
    > for TcpSocketSpec<I, D, BT>
{
    /// Checks whether a listener or bound socket with `sharing` may be
    /// inserted at `addr`.
    ///
    /// Two kinds of conflict are checked:
    /// - an existing entry at one of the addresses yielded by
    ///   `addr.iter_shadows()` that is exclusive, or that has a listener, or
    ///   any entry at all when the new socket is exclusive
    ///   (`InsertError::ShadowAddrExists`);
    /// - any descendant entry unless both it and the new socket use
    ///   `ReuseAddress` (`InsertError::ShadowerExists`).
    fn check_insert_conflicts(
        sharing: &ListenerSharingState,
        addr: &ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>,
        socketmap: &SocketMap<AddrVec<I, D, TcpPortSpec>, Bound<Self>>,
    ) -> Result<(), InsertError> {
        let ListenerSharingState { listening: _, sharing: new_sharing } = sharing;
        let addr = AddrVec::Listen(addr.clone());

        for a in addr.iter_shadows() {
            let Some(existing) = socketmap.get(&a) else {
                continue;
            };
            let listener_state = match existing {
                Bound::Conn(c) => unreachable!("found conn state {c:?} at listener addr {a:?}"),
                Bound::Listen(l) => l,
            };
            let shadow_conflicts = match listener_state {
                ListenerAddrState::ExclusiveListener(_)
                | ListenerAddrState::ExclusiveBound(_) => true,
                ListenerAddrState::Shared { listener, bound: _ } => match new_sharing {
                    SharingState::Exclusive => true,
                    SharingState::ReuseAddress => listener.is_some(),
                },
            };
            if shadow_conflicts {
                return Err(InsertError::ShadowAddrExists);
            }
        }

        // Only reuse-address sockets may coexist with reuse-address shadowers.
        let shadower_conflicts = socketmap.descendant_counts(&addr).any(
            |(AddrVecTag { sharing: tag_sharing, has_device: _, state: _ }, _): &(
                _,
                NonZeroUsize,
            )| {
                !matches!(
                    (tag_sharing, new_sharing),
                    (SharingState::ReuseAddress, SharingState::ReuseAddress)
                )
            },
        );
        if shadower_conflicts {
            return Err(InsertError::ShadowerExists);
        }
        Ok(())
    }
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
    SocketMapConflictPolicy<
        ConnAddr<ConnIpAddr<I::Addr, NonZeroU16, NonZeroU16>, D>,
        SharingState,
        I,
        D,
        TcpPortSpec,
    > for TcpSocketSpec<I, D, BT>
{
    /// Checks whether a connection may be inserted at `addr`.
    ///
    /// The check is made against the device-less form of `addr`. It fails
    /// with `InsertError::Exists` if that address is already occupied, and
    /// with `InsertError::ShadowerExists` if any descendant entry exists. The
    /// sharing state is not consulted.
    fn check_insert_conflicts(
        _sharing: &SharingState,
        addr: &ConnAddr<ConnIpAddr<I::Addr, NonZeroU16, NonZeroU16>, D>,
        socketmap: &SocketMap<AddrVec<I, D, TcpPortSpec>, Bound<Self>>,
    ) -> Result<(), InsertError> {
        let addr = AddrVec::Conn(ConnAddr { device: None, ..*addr });
        if socketmap.get(&addr).is_some() {
            return Err(InsertError::Exists);
        }
        if socketmap.descendant_counts(&addr).len() > 0 {
            return Err(InsertError::ShadowerExists);
        }
        Ok(())
    }
}
/// Socket-map state stored at a connection address.
///
/// A connection address holds exactly one socket: `could_insert` and
/// `try_get_inserter` always reject further sockets.
#[derive(Debug)]
struct ConnAddrState<S> {
    /// Sharing mode the socket was inserted with.
    sharing: SharingState,
    /// The socket occupying the address.
    id: S,
}

impl<S: SpecSocketId> ConnAddrState<S> {
    /// Returns a clone of the stored socket ID.
    #[cfg_attr(feature = "instrumented", track_caller)]
    pub(crate) fn id(&self) -> S {
        self.id.clone()
    }
}

impl<S: SpecSocketId> SocketMapAddrStateSpec for ConnAddrState<S> {
    type Id = S;
    // Insertion is never possible, so no inserter can be produced.
    type Inserter<'a> = Never;
    type SharingState = SharingState;

    /// Creates the entry holding the single socket `id`.
    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self {
        Self { sharing: *new_sharing_state, id }
    }

    /// Returns whether `id` is the socket stored here.
    fn contains_id(&self, id: &Self::Id) -> bool {
        &self.id == id
    }

    /// Always fails: a connection address cannot hold a second socket.
    fn could_insert(
        &self,
        _new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError> {
        Err(IncompatibleError)
    }

    /// Removes `id`, which leaves the entry empty.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not the stored socket.
    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult {
        let Self { sharing: _, id: existing_id } = self;
        assert_eq!(*existing_id, id);
        RemoveResult::IsLast
    }

    /// Always fails: a connection address cannot hold a second socket.
    fn try_get_inserter<'a, 'b>(
        &'b mut self,
        _new_sharing_state: &'a Self::SharingState,
    ) -> Result<Self::Inserter<'b>, IncompatibleError> {
        Err(IncompatibleError)
    }
}
/// State of a socket that has not been bound to a local address yet.
///
/// Holds the settings that `bind` and `connect` pick up when the socket
/// leaves this state.
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub struct Unbound<D, Extra> {
/// Device the socket is bound to, if any; used to resolve the local address
/// and device on `bind`.
bound_device: Option<D>,
/// Buffer sizes to carry into the bound state.
buffer_sizes: BufferSizes,
/// Socket options to carry into the bound state.
socket_options: SocketOptions,
/// Address sharing mode used when the socket is bound or connected.
sharing: SharingState,
/// Bindings-provided value; `listen` later hands it to the listener's
/// accept queue as its notifier.
socket_extra: Takeable<Extra>,
}
/// Lock-protected state shared by every reference to a socket.
type ReferenceState<I, D, BT> = RwLock<TcpSocketState<I, D, BT>>;
/// Primary reference to a socket's state; the socket set holds it in
/// `TcpSocketSetEntry::Primary`.
type PrimaryRc<I, D, BT> = netstack3_base::sync::PrimaryRc<ReferenceState<I, D, BT>>;
/// Strong reference to a socket's state, wrapped by `TcpSocketId`.
type StrongRc<I, D, BT> = netstack3_base::sync::StrongRc<ReferenceState<I, D, BT>>;
/// Weak reference to a socket's state, wrapped by `WeakTcpSocketId`.
type WeakRc<I, D, BT> = netstack3_base::sync::WeakRc<ReferenceState<I, D, BT>>;
/// A value in `TcpSocketSet`.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
pub enum TcpSocketSetEntry<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Holds the socket's primary reference (inserted this way by `create`).
Primary(PrimaryRc<I, D, BT>),
/// Holds no reference. NOTE(review): not constructed in this view;
/// presumably marks a socket whose primary reference was already released —
/// TODO confirm.
DeadOnArrival,
}
/// The set of all TCP sockets of one IP version, keyed by socket ID.
#[derive(Debug, Derivative)]
#[derivative(Default(bound = ""))]
pub struct TcpSocketSet<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
HashMap<TcpSocketId<I, D, BT>, TcpSocketSetEntry<I, D, BT>>,
);
/// Exposes the underlying map for reading.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Deref
for TcpSocketSet<I, D, BT>
{
type Target = HashMap<TcpSocketId<I, D, BT>, TcpSocketSetEntry<I, D, BT>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Exposes the underlying map for modification.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> DerefMut
for TcpSocketSet<I, D, BT>
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Drop
    for TcpSocketSet<I, D, BT>
{
    /// Closes the accept queue of every listening socket still in the set.
    ///
    /// Queues that are already closed are skipped, and whatever `close`
    /// returns is dropped.
    fn drop(&mut self) {
        let Self(sockets) = self;
        for TcpSocketId(rc) in sockets.keys() {
            let state = rc.read();
            let TcpSocketStateInner::Bound(BoundSocketState::Listener((
                MaybeListener::Listener(Listener { accept_queue, .. }),
                ..,
            ))) = &state.socket_state
            else {
                // Only listeners own an accept queue.
                continue;
            };
            if accept_queue.is_closed() {
                continue;
            }
            let (_pending_sockets_iterator, _): (_, BT::ListenerNotifierOrProvidedBuffers) =
                accept_queue.close();
        }
    }
}
/// The bound-socket map type TCP uses for demultiplexing.
type BoundSocketMap<I, D, BT> = socket::BoundSocketMap<I, D, TcpPortSpec, TcpSocketSpec<I, D, BT>>;
/// Demultiplexing state for one IP version.
#[derive(GenericOverIp)]
#[generic_over_ip(I, Ip)]
pub struct DemuxState<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Listener and connection addresses occupied by this IP version's sockets.
socketmap: BoundSocketMap<I, D, BT>,
}
/// TCP socket state for one IP version: the demux map and the set of all
/// sockets, each behind its own lock.
pub struct Sockets<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Bound-address demultiplexing state.
demux: RwLock<DemuxState<I, D, BT>>,
/// Every socket created through `TcpApi::create`, keyed by ID.
all_sockets: RwLock<TcpSocketSet<I, D, BT>>,
}
/// Gives ordered-lock access to the demux state.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
OrderedLockAccess<DemuxState<I, D, BT>> for Sockets<I, D, BT>
{
type Lock = RwLock<DemuxState<I, D, BT>>;
fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
OrderedLockRef::new(&self.demux)
}
}
/// Gives ordered-lock access to the set of all sockets.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
OrderedLockAccess<TcpSocketSet<I, D, BT>> for Sockets<I, D, BT>
{
type Lock = RwLock<TcpSocketSet<I, D, BT>>;
fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
OrderedLockRef::new(&self.all_sockets)
}
}
/// The state of a TCP socket, stored behind the socket's `RwLock`.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
pub struct TcpSocketState<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Where the socket is in its lifecycle.
socket_state: TcpSocketStateInner<I, D, BT>,
/// IP options, including the dual-stack setting consulted by `bind`.
ip_options: I::DualStackIpOptions,
}
/// The lifecycle state of a TCP socket.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
pub enum TcpSocketStateInner<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Not yet bound to a local address.
Unbound(Unbound<D, BT::ListenerNotifierOrProvidedBuffers>),
/// Bound to a local address: bound-only, listening, or connected.
Bound(BoundSocketState<I, D, BT>),
}
/// Port allocator that checks availability in a single stack's socket map.
struct TcpPortAlloc<'a, I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
    &'a BoundSocketMap<I, D, BT>,
);

impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> PortAllocImpl
    for TcpPortAlloc<'_, I, D, BT>
{
    /// The IANA dynamic/private port range.
    const EPHEMERAL_RANGE: RangeInclusive<u16> = 49152..=65535;
    /// Local address the port is for (`None` means unspecified).
    type Id = Option<SocketIpAddr<I::Addr>>;
    /// A port that must not be reported as available, if any.
    type PortAvailableArg = Option<NonZeroU16>;

    /// Returns whether `port` can be used for a new device-less listener on
    /// `addr`.
    ///
    /// The port is unavailable if it equals `arg`, if a listener exists at
    /// the root address or at any address from its `iter_shadows()`, or if
    /// the root address has shadowers.
    ///
    /// # Panics
    ///
    /// Panics if `port` is zero.
    fn is_port_available(&self, addr: &Self::Id, port: u16, arg: &Option<NonZeroU16>) -> bool {
        let Self(socketmap) = self;
        let port = NonZeroU16::new(port).unwrap();
        if *arg == Some(port) {
            return false;
        }

        let root_addr = AddrVec::from(ListenerAddr {
            ip: ListenerIpAddr { addr: *addr, identifier: port },
            device: None,
        });
        let mut candidates =
            root_addr.iter_shadows().chain(core::iter::once(root_addr.clone()));
        let no_listeners = candidates.all(|candidate| match &candidate {
            AddrVec::Listen(l) => socketmap.listeners().get_by_addr(&l).is_none(),
            AddrVec::Conn(_c) => {
                unreachable!("no connection shall be included in an iteration from a listener")
            }
        });
        no_listeners && socketmap.get_shadower_counts(&root_addr) == 0
    }
}
/// Port allocator for sockets bound in both stacks: a port is available only
/// if it is free in this stack's map and in the other stack's map.
struct TcpDualStackPortAlloc<'a, I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
    &'a BoundSocketMap<I, D, BT>,
    &'a BoundSocketMap<I::OtherVersion, D, BT>,
);

impl<'a, I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> PortAllocImpl
    for TcpDualStackPortAlloc<'a, I, D, BT>
{
    /// Same range as the single-stack allocator.
    const EPHEMERAL_RANGE: RangeInclusive<u16> =
        <TcpPortAlloc<'a, I, D, BT> as PortAllocImpl>::EPHEMERAL_RANGE;
    type Id = ();
    type PortAvailableArg = ();

    /// Returns whether `port` is free for an unspecified-address listener in
    /// both stacks. The other stack is only checked if this one is free.
    fn is_port_available(&self, (): &Self::Id, port: u16, (): &Self::PortAvailableArg) -> bool {
        let Self(this_stack, other_stack) = self;
        let free_in_this_stack = TcpPortAlloc(this_stack).is_port_available(&None, port, &None);
        free_in_this_stack && TcpPortAlloc(other_stack).is_port_available(&None, port, &None)
    }
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Sockets<I, D, BT> {
    /// Creates socket state with an empty demux map and an empty socket set.
    pub(crate) fn new() -> Self {
        let demux_state = DemuxState { socketmap: Default::default() };
        let all_sockets = Default::default();
        Self { demux: RwLock::new(demux_state), all_sockets }
    }
}
/// State of a connected TCP socket.
///
/// `SockI` is the socket's IP version and `WireI` the IP version of the
/// underlying IP socket (`ip_sock`); the two may differ for dual-stack sockets.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
pub struct Connection<
SockI: DualStackIpExt,
WireI: DualStackIpExt,
D: WeakDeviceIdentifier,
BT: TcpBindingsTypes,
> {
/// Accept queue associated with this connection, if any; `TcpApi::accept`
/// resets it to `None`.
accept_queue: Option<
AcceptQueue<
TcpSocketId<SockI, D, BT>,
BT::ReturnedBuffers,
BT::ListenerNotifierOrProvidedBuffers,
>,
>,
/// The connection's TCP state; ICMP errors are forwarded to it by
/// `on_icmp_error`.
state: State<
BT::Instant,
BT::ReceiveBuffer,
BT::SendBuffer,
BT::ListenerNotifierOrProvidedBuffers,
>,
/// The underlying IP socket.
ip_sock: IpSock<WireI, D>,
/// NOTE(review): not read or written in this view; meaning not evident here.
defunct: bool,
/// Socket options applied to this connection.
socket_options: SocketOptions,
/// A recorded soft error, if any; `on_icmp_error` never overwrites an
/// existing one.
soft_error: Option<ConnectionError>,
/// Handshake progress, reported to callers through `TcpApi::connect`.
handshake_status: HandshakeStatus,
}
impl<
        SockI: DualStackIpExt,
        WireI: DualStackIpExt,
        D: WeakDeviceIdentifier,
        BT: TcpBindingsTypes,
    > Connection<SockI, WireI, D, BT>
{
    /// Forwards an ICMP `error` (with sequence number `seq`) to the
    /// connection's state, recording counters through `core_ctx`.
    ///
    /// A soft error produced by the state is stored only if no soft error is
    /// already recorded. Returns the state's `NewlyClosed` indication.
    fn on_icmp_error<CC: CounterContext<TcpCounters<SockI>>>(
        &mut self,
        core_ctx: &mut CC,
        seq: SeqNum,
        error: IcmpErrorCode,
    ) -> NewlyClosed {
        let Connection { soft_error, state, .. } = self;
        let (maybe_soft_error, newly_closed) =
            core_ctx.with_counters(|counters| state.on_icmp_error(counters, error, seq));
        // Keep the earliest recorded soft error.
        if soft_error.is_none() {
            *soft_error = maybe_soft_error;
        }
        newly_closed
    }
}
/// State of a listening socket.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
#[cfg_attr(
test,
derivative(
PartialEq(
bound = "BT::ReturnedBuffers: PartialEq, BT::ListenerNotifierOrProvidedBuffers: PartialEq"
),
Eq(bound = "BT::ReturnedBuffers: Eq, BT::ListenerNotifierOrProvidedBuffers: Eq"),
)
)]
pub struct Listener<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Backlog passed to `listen`.
backlog: NonZeroUsize,
/// Queue of connections for this listener; `accept` pops ready ones.
accept_queue: AcceptQueue<
TcpSocketId<I, D, BT>,
BT::ReturnedBuffers,
BT::ListenerNotifierOrProvidedBuffers,
>,
/// Buffer sizes copied from the bound state by `listen`.
buffer_sizes: BufferSizes,
/// Socket options copied from the bound state by `listen`.
socket_options: SocketOptions,
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Listener<I, D, BT> {
    /// Creates listener state whose accept queue is built from `notifier`.
    fn new(
        backlog: NonZeroUsize,
        buffer_sizes: BufferSizes,
        socket_options: SocketOptions,
        notifier: BT::ListenerNotifierOrProvidedBuffers,
    ) -> Self {
        let accept_queue = AcceptQueue::new(notifier);
        Self { backlog, accept_queue, buffer_sizes, socket_options }
    }
}
/// State of a socket bound to a local address that is not listening.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
pub struct BoundState<Extra> {
/// Buffer sizes, copied into the listener by `listen`.
buffer_sizes: BufferSizes,
/// Socket options, copied into the listener by `listen`.
socket_options: SocketOptions,
/// Bindings-provided value, taken by `listen` as the accept-queue notifier.
socket_extra: Takeable<Extra>,
}
/// A socket at a listener address: either only bound, or listening.
#[derive(Derivative)]
#[derivative(Debug(bound = "D: Debug"))]
#[cfg_attr(
test,
derivative(
Eq(bound = "BT::ReturnedBuffers: Eq, BT::ListenerNotifierOrProvidedBuffers: Eq"),
PartialEq(
bound = "BT::ReturnedBuffers: PartialEq, BT::ListenerNotifierOrProvidedBuffers: PartialEq"
)
)
)]
pub enum MaybeListener<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> {
/// Bound but not listening.
Bound(BoundState<BT::ListenerNotifierOrProvidedBuffers>),
/// Listening for connections.
Listener(Listener<I, D, BT>),
}
/// A strong reference identifying a TCP socket.
///
/// Equality and hashing are delegated to the wrapped `StrongRc`.
#[derive(Derivative, GenericOverIp)]
#[generic_over_ip(I, Ip)]
#[derivative(Eq(bound = ""), PartialEq(bound = ""), Hash(bound = ""))]
pub struct TcpSocketId<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
StrongRc<I, D, BT>,
);
/// Clones the strong reference; both IDs refer to the same socket state.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Clone
for TcpSocketId<I, D, BT>
{
#[cfg_attr(feature = "instrumented", track_caller)]
fn clone(&self) -> Self {
let Self(rc) = self;
Self(StrongRc::clone(rc))
}
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> TcpSocketId<I, D, BT> {
    /// Allocates state for a new socket in `socket_state`, with default IP
    /// options.
    ///
    /// Returns a strong ID for the socket together with its primary reference.
    pub(crate) fn new(socket_state: TcpSocketStateInner<I, D, BT>) -> (Self, PrimaryRc<I, D, BT>) {
        let state = TcpSocketState { socket_state, ip_options: Default::default() };
        let primary = PrimaryRc::new(RwLock::new(state));
        (Self(PrimaryRc::clone_strong(&primary)), primary)
    }

    /// Like [`TcpSocketId::new`], but builds the initial state with `init`,
    /// which receives a weak ID of the socket being created.
    pub(crate) fn new_cyclic<
        F: FnOnce(WeakTcpSocketId<I, D, BT>) -> TcpSocketStateInner<I, D, BT>,
    >(
        init: F,
    ) -> (Self, PrimaryRc<I, D, BT>) {
        let primary = PrimaryRc::new_cyclic(move |weak| {
            let weak_id = WeakTcpSocketId(weak);
            RwLock::new(TcpSocketState {
                socket_state: init(weak_id),
                ip_options: Default::default(),
            })
        });
        (Self(PrimaryRc::clone_strong(&primary)), primary)
    }
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Debug
    for TcpSocketId<I, D, BT>
{
    /// Formats the ID as `TcpSocketId(<debug id of the reference>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self(strong) = self;
        let debug_id = StrongRc::debug_id(strong);
        f.debug_tuple("TcpSocketId").field(&debug_id).finish()
    }
}

impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> TcpSocketId<I, D, BT> {
    /// Returns a weak ID for the same socket.
    pub(crate) fn downgrade(&self) -> WeakTcpSocketId<I, D, BT> {
        let Self(strong) = self;
        let weak = StrongRc::downgrade(strong);
        WeakTcpSocketId(weak)
    }
}
/// A weak reference to a TCP socket; `upgrade` turns it back into a
/// `TcpSocketId` if the socket still exists.
#[derive(Derivative, GenericOverIp)]
#[generic_over_ip(I, Ip)]
#[derivative(Clone(bound = ""), Eq(bound = ""), PartialEq(bound = ""), Hash(bound = ""))]
pub struct WeakTcpSocketId<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
WeakRc<I, D, BT>,
);
/// Formats the ID as `WeakTcpSocketId(<debug id of the reference>)`.
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> Debug
for WeakTcpSocketId<I, D, BT>
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let Self(rc) = self;
f.debug_tuple("WeakTcpSocketId").field(&rc.debug_id()).finish()
}
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
    PartialEq<TcpSocketId<I, D, BT>> for WeakTcpSocketId<I, D, BT>
{
    /// Returns whether this weak ID and `other` refer to the same socket.
    fn eq(&self, other: &TcpSocketId<I, D, BT>) -> bool {
        let TcpSocketId(strong) = other;
        let Self(weak) = self;
        StrongRc::weak_ptr_eq(strong, weak)
    }
}

impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> WeakTcpSocketId<I, D, BT> {
    /// Returns a strong ID for the socket, or `None` if it can no longer be
    /// upgraded.
    #[cfg_attr(feature = "instrumented", track_caller)]
    pub(crate) fn upgrade(&self) -> Option<TcpSocketId<I, D, BT>> {
        let Self(weak) = self;
        weak.upgrade().map(TcpSocketId)
    }
}

impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>
    OrderedLockAccess<TcpSocketState<I, D, BT>> for TcpSocketId<I, D, BT>
{
    type Lock = RwLock<TcpSocketState<I, D, BT>>;

    /// Gives ordered-lock access to the socket's state.
    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
        let Self(strong) = self;
        OrderedLockRef::new(&*strong)
    }
}
/// Progress of a connection's handshake, as reported through `connect`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum HandshakeStatus {
    /// The handshake has not finished.
    Pending,
    /// The handshake was aborted.
    Aborted,
    /// The handshake completed.
    Completed {
        /// Whether completion has already been reported to a `connect` caller.
        reported: bool,
    },
}

impl HandshakeStatus {
    /// Replaces the status with `new_status` if it is currently `Pending`.
    ///
    /// Returns `true` if the replacement happened, `false` otherwise (leaving
    /// the status untouched).
    fn update_if_pending(&mut self, new_status: Self) -> bool {
        let was_pending = matches!(*self, HandshakeStatus::Pending);
        if was_pending {
            *self = new_status;
        }
        was_pending
    }
}
/// Resolves the local address and device a socket should bind to.
///
/// With an `addr`, its zone is resolved against `bound_device`, and the
/// address must be assigned to some device — to the required device, when
/// resolution produced one. Without an `addr`, the socket binds to no
/// specific address on `bound_device`.
///
/// # Errors
///
/// - `LocalAddressError::Zone` if resolving the zone against `bound_device`
///   fails.
/// - `LocalAddressError::AddressMismatch` if no acceptable device has the
///   address assigned.
fn bind_get_local_addr_and_device<I, BT, CC>(
    core_ctx: &mut CC,
    addr: Option<ZonedAddr<SocketIpAddr<I::Addr>, CC::DeviceId>>,
    bound_device: &Option<CC::WeakDeviceId>,
) -> Result<(Option<SocketIpAddr<I::Addr>>, Option<CC::WeakDeviceId>), LocalAddressError>
where
    I: DualStackIpExt,
    BT: TcpBindingsTypes,
    CC: TransportIpContext<I, BT>,
{
    let (local_ip, device) = if let Some(zoned) = addr {
        let (ip, required_device) = zoned
            .resolve_addr_with_device(bound_device.clone())
            .map_err(LocalAddressError::Zone)?;
        core_ctx.with_devices_with_assigned_addr(ip.clone().into(), |mut assigned_to| {
            let found = assigned_to.any(|d| {
                required_device
                    .as_ref()
                    .map_or(true, |device| device == &EitherDeviceId::Strong(d))
            });
            if found {
                Ok(())
            } else {
                Err(LocalAddressError::AddressMismatch)
            }
        })?;
        (Some(ip), required_device)
    } else {
        (None, bound_device.clone().map(EitherDeviceId::Weak))
    };
    let weak_device = device.map(|d| d.as_weak().into_owned());
    Ok((local_ip, weak_device))
}
/// Inserts `demux_socket_id` into the demux map as a bound, non-listening
/// socket at `local_ip`:`port` on `weak_device`.
///
/// When `port` is `None`, a randomized port that is free for `local_ip` is
/// allocated.
///
/// Returns the listener address and the sharing state used for insertion.
///
/// # Errors
///
/// - `LocalAddressError::FailedToAllocateLocalPort` if no port is free.
/// - `LocalAddressError::AddressInUse` if the insertion conflicts with
///   existing sockets.
fn bind_install_in_demux<I, D, BC>(
    bindings_ctx: &mut BC,
    demux_socket_id: I::DemuxSocketId<D, BC>,
    local_ip: Option<SocketIpAddr<I::Addr>>,
    weak_device: Option<D>,
    port: Option<NonZeroU16>,
    sharing: SharingState,
    DemuxState { socketmap }: &mut DemuxState<I, D, BC>,
) -> Result<
    (ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, D>, ListenerSharingState),
    LocalAddressError,
>
where
    I: DualStackIpExt,
    BC: TcpBindingsTypes + RngContext,
    D: WeakDeviceIdentifier,
{
    let port = if let Some(requested) = port {
        requested
    } else {
        let allocated = netstack3_base::simple_randomized_port_alloc(
            &mut bindings_ctx.rng(),
            &local_ip,
            &TcpPortAlloc(socketmap),
            &None,
        )
        .ok_or(LocalAddressError::FailedToAllocateLocalPort)?;
        NonZeroU16::new(allocated).expect("ephemeral ports must be non-zero")
    };

    let listener_addr = ListenerAddr {
        ip: ListenerIpAddr { addr: local_ip, identifier: port },
        device: weak_device,
    };
    // A freshly bound socket is not listening yet.
    let listener_sharing = ListenerSharingState { sharing, listening: false };
    let _inserted = socketmap
        .listeners_mut()
        .try_insert(listener_addr.clone(), listener_sharing.clone(), demux_socket_id)
        .map_err(|_: (InsertError, ListenerSharingState)| LocalAddressError::AddressInUse)?;
    Ok((listener_addr, listener_sharing))
}
/// Changes the sharing state of the listener-address socket `id`, bound at
/// `addr`, from `sharing` to `new_sharing` in the demux socket map(s).
///
/// For a dual-stack socket bound in both stacks, the update is applied to
/// both stacks' maps. If the other stack rejects it, this stack's change is
/// reverted before the error is returned.
///
/// Returns `new_sharing` on success.
///
/// # Panics
///
/// Panics if `id` is not found at `addr` ("invalid listener id"), or if
/// reverting a partial both-stacks update fails.
fn try_update_listener_sharing<I, CC, BT>(
core_ctx: MaybeDualStack<
(&mut CC::DualStackIpTransportAndDemuxCtx<'_>, CC::DualStackConverter),
(&mut CC::SingleStackIpTransportAndDemuxCtx<'_>, CC::SingleStackConverter),
>,
id: &TcpSocketId<I, CC::WeakDeviceId, BT>,
addr: ListenerAddr<I::ListenerIpAddr, CC::WeakDeviceId>,
sharing: &ListenerSharingState,
new_sharing: ListenerSharingState,
) -> Result<ListenerSharingState, UpdateSharingError>
where
I: DualStackIpExt,
CC: TcpContext<I, BT>,
BT: TcpBindingsTypes,
{
match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
core_ctx.with_demux_mut(|DemuxState { socketmap }| {
let mut entry = socketmap
.listeners_mut()
.entry(&I::into_demux_socket_id(id.clone()), &converter.convert(addr))
.expect("invalid listener id");
entry.try_update_sharing(sharing, new_sharing)
})
}
MaybeDualStack::DualStack((core_ctx, converter)) => match converter.convert(addr) {
ListenerAddr { ip: DualStackListenerIpAddr::ThisStack(ip), device } => {
TcpDemuxContext::<I, _, _>::with_demux_mut(core_ctx, |DemuxState { socketmap }| {
let mut entry = socketmap
.listeners_mut()
.entry(&I::into_demux_socket_id(id.clone()), &ListenerAddr { ip, device })
.expect("invalid listener id");
entry.try_update_sharing(sharing, new_sharing)
})
}
ListenerAddr { ip: DualStackListenerIpAddr::OtherStack(ip), device } => {
let demux_id = core_ctx.into_other_demux_socket_id(id.clone());
TcpDemuxContext::<I::OtherVersion, _, _>::with_demux_mut(
core_ctx,
|DemuxState { socketmap }| {
let mut entry = socketmap
.listeners_mut()
.entry(&demux_id, &ListenerAddr { ip, device })
.expect("invalid listener id");
entry.try_update_sharing(sharing, new_sharing)
},
)
}
// Bound in both stacks: update this stack first, then the other one,
// reverting this stack's change if the other stack rejects the update.
ListenerAddr { ip: DualStackListenerIpAddr::BothStacks(port), device } => {
let other_demux_id = core_ctx.into_other_demux_socket_id(id.clone());
let demux_id = I::into_demux_socket_id(id.clone());
core_ctx.with_both_demux_mut(
|DemuxState { socketmap: this_socketmap, .. },
DemuxState { socketmap: other_socketmap, .. }| {
let this_stack_listener_addr = ListenerAddr {
ip: ListenerIpAddr { addr: None, identifier: port },
device: device.clone(),
};
let mut this_stack_entry = this_socketmap
.listeners_mut()
.entry(&demux_id, &this_stack_listener_addr)
.expect("invalid listener id");
this_stack_entry.try_update_sharing(sharing, new_sharing)?;
let mut other_stack_entry = other_socketmap
.listeners_mut()
.entry(
&other_demux_id,
&ListenerAddr {
ip: ListenerIpAddr { addr: None, identifier: port },
device,
},
)
.expect("invalid listener id");
match other_stack_entry.try_update_sharing(sharing, new_sharing) {
Ok(()) => Ok(()),
Err(err) => {
// Roll back so both stacks keep the original sharing state.
this_stack_entry
.try_update_sharing(&new_sharing, *sharing)
.expect("failed to revert the sharing setting");
Err(err)
}
}
},
)
}
},
}?;
Ok(new_sharing)
}
/// The TCP socket API for IP version `I`, operating on context pair `C`.
pub struct TcpApi<I: Ip, C>(C, IpVersionMarker<I>);
impl<I: Ip, C> TcpApi<I, C> {
/// Creates an API instance that operates on `ctx`.
pub fn new(ctx: C) -> Self {
Self(ctx, IpVersionMarker::new())
}
}
/// The socket ID type produced by [`TcpApi`] for context pair `C`.
type TcpApiSocketId<I, C> = TcpSocketId<
I,
<<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
<C as ContextPair>::BindingsContext,
>;
impl<I, C> TcpApi<I, C>
where
I: DualStackIpExt,
C: ContextPair,
C::CoreContext: TcpContext<I, C::BindingsContext> + CounterContext<TcpCounters<I>>,
C::BindingsContext: TcpBindingsContext,
{
/// Returns the core context of the wrapped context pair.
fn core_ctx(&mut self) -> &mut C::CoreContext {
let Self(pair, IpVersionMarker { .. }) = self;
pair.core_ctx()
}
/// Returns the core and bindings contexts of the wrapped context pair.
fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
let Self(pair, IpVersionMarker { .. }) = self;
pair.contexts()
}
pub fn create(
&mut self,
socket_extra: <C::BindingsContext as TcpBindingsTypes>::ListenerNotifierOrProvidedBuffers,
) -> TcpApiSocketId<I, C> {
self.core_ctx().with_all_sockets_mut(|all_sockets| {
let (sock, primary) = TcpSocketId::new(TcpSocketStateInner::Unbound(Unbound {
bound_device: Default::default(),
buffer_sizes: C::BindingsContext::default_buffer_sizes(),
sharing: Default::default(),
socket_options: Default::default(),
socket_extra: Takeable::new(socket_extra),
}));
assert_matches::assert_matches!(
all_sockets.insert(sock.clone(), TcpSocketSetEntry::Primary(primary)),
None
);
sock
})
}
/// Binds the unbound socket `id` to a local address and port.
///
/// `addr` selects the local address. When it is `None`, an IPv4 socket binds
/// in its own stack only. An IPv6 socket binds in both stacks when dual-stack
/// is enabled for it, and only in the IPv6 stack when it is disabled. `port`
/// selects the local port; when it is `None` an ephemeral port is allocated
/// (for a both-stacks bind, one that is free in both stacks).
///
/// On success the socket becomes a bound, non-listening socket.
///
/// # Errors
///
/// - `BindError::AlreadyBound` if the socket is not unbound.
/// - `LocalAddressError::CannotBindToAddress` when the request needs the other
///   IP version's stack but dual-stack is unavailable or disabled.
/// - Other `LocalAddressError`s from address/device resolution, port
///   allocation, or insertion into the demux socket map.
///
/// A port allocation failure also increments the `failed_port_reservations`
/// counter.
pub fn bind(
&mut self,
id: &TcpApiSocketId<I, C>,
addr: Option<
ZonedAddr<
SpecifiedAddr<I::Addr>,
<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId,
>,
>,
port: Option<NonZeroU16>,
) -> Result<(), BindError> {
// Which stack(s) a bind request targets; `BindInOneStack` may carry a
// specific local address.
#[derive(GenericOverIp)]
#[generic_over_ip(I, Ip)]
enum BindAddr<I: DualStackIpExt, D> {
BindInBothStacks,
BindInOneStack(
EitherStack<
Option<ZonedAddr<SocketIpAddr<I::Addr>, D>>,
Option<ZonedAddr<SocketIpAddr<<I::OtherVersion as Ip>::Addr>, D>>,
>,
),
}
debug!("bind {id:?} to {addr:?}:{port:?}");
// Without an address, IPv4 targets its own stack and IPv6 tentatively
// targets both stacks (refined below by the dual-stack setting).
let bind_addr = match addr {
None => I::map_ip(
(),
|()| BindAddr::BindInOneStack(EitherStack::ThisStack(None)),
|()| BindAddr::BindInBothStacks,
),
Some(addr) => match DualStackLocalIp::<I, _>::new(addr) {
DualStackLocalIp::ThisStack(addr) => {
BindAddr::BindInOneStack(EitherStack::ThisStack(Some(addr)))
}
DualStackLocalIp::OtherStack(addr) => {
BindAddr::BindInOneStack(EitherStack::OtherStack(addr))
}
},
};
let (core_ctx, bindings_ctx) = self.contexts();
let result = core_ctx.with_socket_mut_transport_demux(id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options } = socket_state;
// Only unbound sockets can be bound.
let Unbound { bound_device, buffer_sizes, socket_options, sharing, socket_extra } =
match socket_state {
TcpSocketStateInner::Unbound(u) => u,
TcpSocketStateInner::Bound(_) => return Err(BindError::AlreadyBound),
};
let (listener_addr, sharing) = match core_ctx {
// Without dual-stack support only this stack's addresses are bindable.
MaybeDualStack::NotDualStack((core_ctx, converter)) => match bind_addr {
BindAddr::BindInOneStack(EitherStack::ThisStack(local_addr)) => {
let (local_addr, device) = bind_get_local_addr_and_device(core_ctx, local_addr, bound_device)?;
let (addr, sharing) =
core_ctx.with_demux_mut(|demux| {
bind_install_in_demux(
bindings_ctx,
I::into_demux_socket_id(id.clone()),
local_addr,
device,
port,
*sharing,
demux,
)
})?;
(converter.convert_back(addr), sharing)
}
BindAddr::BindInOneStack(EitherStack::OtherStack(_)) | BindAddr::BindInBothStacks => {
return Err(LocalAddressError::CannotBindToAddress.into());
}
},
MaybeDualStack::DualStack((core_ctx, converter)) => {
// Refine the request using the socket's dual-stack setting.
let bind_addr = match (
core_ctx.dual_stack_enabled(&ip_options),
bind_addr
) {
(true, BindAddr::BindInBothStacks)
=> BindAddr::<I, _>::BindInBothStacks,
(false, BindAddr::BindInBothStacks)
=> BindAddr::BindInOneStack(EitherStack::ThisStack(None)),
(true | false, BindAddr::BindInOneStack(EitherStack::ThisStack(ip)))
=> BindAddr::BindInOneStack(EitherStack::ThisStack(ip)),
(true, BindAddr::BindInOneStack(EitherStack::OtherStack(ip)))
=> BindAddr::BindInOneStack(EitherStack::OtherStack(ip)),
(false, BindAddr::BindInOneStack(EitherStack::OtherStack(_)))
=> return Err(LocalAddressError::CannotBindToAddress.into()),
};
match bind_addr {
BindAddr::BindInOneStack(EitherStack::ThisStack(addr)) => {
let (addr, device) = bind_get_local_addr_and_device::<I, _, _>(core_ctx, addr, bound_device)?;
let (ListenerAddr { ip, device }, sharing) =
core_ctx.with_demux_mut(|demux: &mut DemuxState<I, _, _>| {
bind_install_in_demux(
bindings_ctx,
I::into_demux_socket_id(id.clone()),
addr,
device,
port,
*sharing,
demux,
)
})?;
(
converter.convert_back(ListenerAddr {
ip: DualStackListenerIpAddr::ThisStack(ip),
device,
}),
sharing,
)
}
BindAddr::BindInOneStack(EitherStack::OtherStack(addr)) => {
let other_demux_id = core_ctx.into_other_demux_socket_id(id.clone());
let (addr, device) = bind_get_local_addr_and_device::<I::OtherVersion, _, _>(core_ctx, addr, bound_device)?;
let (ListenerAddr { ip, device }, sharing) =
core_ctx.with_demux_mut(|demux: &mut DemuxState<I::OtherVersion, _, _>| {
bind_install_in_demux(
bindings_ctx,
other_demux_id,
addr,
device,
port,
*sharing,
demux,
)
})?;
(
converter.convert_back(ListenerAddr {
ip: DualStackListenerIpAddr::OtherStack(ip),
device,
}),
sharing,
)
}
// Pick a port free in both stacks (unless one was given), install
// the socket in both demux maps, and undo this stack's insertion if
// the other stack's insertion fails.
BindAddr::BindInBothStacks => {
let other_demux_id = core_ctx.into_other_demux_socket_id(id.clone());
let (port, device, sharing) =
core_ctx.with_both_demux_mut(|demux, other_demux| {
let port_alloc = TcpDualStackPortAlloc(
&demux.socketmap,
&other_demux.socketmap
);
let port = match port {
Some(port) => port,
None => match netstack3_base::simple_randomized_port_alloc(
&mut bindings_ctx.rng(),
&(),
&port_alloc,
&(),
){
Some(port) => NonZeroU16::new(port)
.expect("ephemeral ports must be non-zero"),
None => {
return Err(LocalAddressError::FailedToAllocateLocalPort);
}
}
};
let (this_stack_addr, this_stack_sharing) = bind_install_in_demux(
bindings_ctx,
I::into_demux_socket_id(id.clone()),
None,
bound_device.clone(),
Some(port),
*sharing,
demux,
)?;
match bind_install_in_demux(
bindings_ctx,
other_demux_id,
None,
bound_device.clone(),
Some(port),
*sharing,
other_demux,
) {
Ok((ListenerAddr { ip, device }, other_stack_sharing)) => {
assert_eq!(this_stack_addr.ip.identifier, ip.identifier);
assert_eq!(this_stack_sharing, other_stack_sharing);
Ok((port, device, this_stack_sharing))
}
Err(err) => {
demux.socketmap.listeners_mut().remove(&I::into_demux_socket_id(id.clone()), &this_stack_addr).expect("failed to unbind");
Err(err)
}
}
})?;
(
ListenerAddr {
ip: converter.convert_back(DualStackListenerIpAddr::BothStacks(port)),
device,
},
sharing,
)
}
}
},
};
// The socket is now bound but not yet listening.
let bound_state = BoundState {
buffer_sizes: buffer_sizes.clone(),
socket_options: socket_options.clone(),
socket_extra: Takeable::from_ref(socket_extra.to_ref()),
};
*socket_state = TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Bound(bound_state),
sharing,
listener_addr,
)));
Ok(())
});
// Count failures to allocate an ephemeral port.
match &result {
Err(BindError::LocalAddressError(LocalAddressError::FailedToAllocateLocalPort)) => {
core_ctx.increment(|counters| &counters.failed_port_reservations);
}
Err(_) | Ok(_) => {}
}
result
}
/// Starts listening for connections on the bound socket `id`.
///
/// The socket's address is marked as listening in the demux map(s) with its
/// current sharing mode. Its bound state is then replaced by listener state
/// holding `backlog`, the bound buffer sizes and socket options, and the
/// bindings-provided extra value.
///
/// # Errors
///
/// - `ListenError::NotSupported` if the socket is unbound, connected, or
///   already listening.
/// - `ListenError::ListenerExists` if the demux map rejects marking the
///   address as listening.
pub fn listen(
&mut self,
id: &TcpApiSocketId<I, C>,
backlog: NonZeroUsize,
) -> Result<(), ListenError> {
debug!("listen on {id:?} with backlog {backlog}");
self.core_ctx().with_socket_mut_transport_demux(id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
// Only a bound, not-yet-listening socket may start listening.
let (listener, listener_sharing, addr) = match socket_state {
TcpSocketStateInner::Bound(BoundSocketState::Listener((l, sharing, addr))) => {
match l {
MaybeListener::Listener(_) => return Err(ListenError::NotSupported),
MaybeListener::Bound(_) => (l, sharing, addr),
}
}
TcpSocketStateInner::Bound(BoundSocketState::Connected { .. })
| TcpSocketStateInner::Unbound(_) => return Err(ListenError::NotSupported),
};
// Same sharing mode, but now listening.
let new_sharing = {
let ListenerSharingState { sharing, listening } = listener_sharing;
debug_assert!(!*listening, "invalid bound ID that has a listener socket");
ListenerSharingState { sharing: *sharing, listening: true }
};
*listener_sharing = try_update_listener_sharing::<_, C::CoreContext, _>(
core_ctx,
id,
addr.clone(),
listener_sharing,
new_sharing,
)
.map_err(|UpdateSharingError| ListenError::ListenerExists)?;
// Convert the bound state into listener state.
match listener {
MaybeListener::Bound(BoundState { buffer_sizes, socket_options, socket_extra }) => {
*listener = MaybeListener::Listener(Listener::new(
backlog,
buffer_sizes.clone(),
socket_options.clone(),
socket_extra.to_ref().take(),
));
}
MaybeListener::Listener(_) => {
unreachable!("invalid bound id that points to a listener entry")
}
}
Ok(())
})
}
/// Accepts a ready connection from the listening socket `id`.
///
/// Returns the accepted connection's socket ID, its remote address, and the
/// buffers that the accept queue returned with it. The accepted connection's
/// `accept_queue` link is cleared.
///
/// # Errors
///
/// - `AcceptError::NotSupported` if `id` is not a listening socket.
/// - `AcceptError::WouldBlock` if no connection is ready.
///
/// # Panics
///
/// Panics if the dequeued socket is not in the connected state.
pub fn accept(
&mut self,
id: &TcpApiSocketId<I, C>,
) -> Result<
(
TcpApiSocketId<I, C>,
SocketAddr<I::Addr, <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
<C::BindingsContext as TcpBindingsTypes>::ReturnedBuffers,
),
AcceptError,
> {
let (conn_id, client_buffers) = self.core_ctx().with_socket_mut(id, |socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
debug!("accept on {id:?}");
// Only listening sockets have an accept queue.
let Listener { backlog: _, buffer_sizes: _, socket_options: _, accept_queue } =
match socket_state {
TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Listener(l),
_sharing,
_addr,
))) => l,
TcpSocketStateInner::Unbound(_)
| TcpSocketStateInner::Bound(BoundSocketState::Connected { .. })
| TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Bound(_),
_,
_,
))) => return Err(AcceptError::NotSupported),
};
let (conn_id, client_buffers) =
accept_queue.pop_ready().ok_or(AcceptError::WouldBlock)?;
Ok::<_, AcceptError>((conn_id, client_buffers))
})?;
// Detach the accepted connection from the accept queue and read its
// remote address.
let remote_addr =
self.core_ctx().with_socket_mut_and_converter(&conn_id, |socket_state, _converter| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
let conn_and_addr = assert_matches!(
socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected{ conn, .. }) => conn,
"invalid socket ID"
);
*I::get_accept_queue_mut(conn_and_addr) = None;
let ConnectionInfo { local_addr: _, remote_addr, device: _ } =
I::get_conn_info(conn_and_addr);
remote_addr
});
debug!("accepted connection {conn_id:?} from {remote_addr:?} on {id:?}");
Ok((conn_id, remote_addr, client_buffers))
}
/// Connects the socket `id` to `remote_ip`:`remote_port`.
///
/// What happens depends on the socket's current state:
/// - Unbound: a new connection is created via `connect_inner` with no local
///   address chosen yet.
/// - Bound but not listening: the bound local address (per IP stack) is used
///   for the new connection.
/// - Listening: fails with [`ConnectError::Listener`].
/// - Already connected: no new connection is made. The handshake status is
///   reported instead:
///   - [`ConnectError::Pending`] while the handshake is in progress;
///   - [`ConnectError::Aborted`] if it was aborted;
///   - `Ok(())` the first time a completed handshake is observed;
///   - [`ConnectError::Completed`] on later calls.
///
/// A remote address in the other IP stack is only usable when the core
/// context is dual-stack capable and dual-stack is enabled in the socket's IP
/// options. Mismatches between the bound local stack and the remote stack
/// yield [`ConnectError::NoRoute`].
///
/// On any error the `failed_connection_attempts` counter is incremented. In
/// addition, `active_open_no_route_errors` is incremented for `NoRoute` and
/// `failed_port_reservations` for `NoPort`.
pub fn connect(
&mut self,
id: &TcpApiSocketId<I, C>,
remote_ip: Option<
ZonedAddr<
SpecifiedAddr<I::Addr>,
<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId,
>,
>,
remote_port: NonZeroU16,
) -> Result<(), ConnectError> {
let (core_ctx, bindings_ctx) = self.contexts();
let result =
core_ctx.with_socket_mut_isn_transport_demux(id, |core_ctx, socket_state, isn| {
let TcpSocketState { socket_state, ip_options } = socket_state;
debug!("connect on {id:?} to {remote_ip:?}:{remote_port}");
// Classify the remote address as belonging to this stack or the other one.
let remote_ip = DualStackRemoteIp::<I, _>::new(remote_ip);
// Collect the parameters for a new connection from the current socket
// state. Already-connected sockets return early with their handshake status.
let (local_addr, sharing, socket_options, buffer_sizes, socket_extra) =
match socket_state {
TcpSocketStateInner::Bound(BoundSocketState::Connected {
conn,
sharing: _,
timer: _,
}) => {
let handshake_status = match core_ctx {
MaybeDualStack::NotDualStack((_core_ctx, converter)) => {
let (conn, _addr) = converter.convert(conn);
&mut conn.handshake_status
}
MaybeDualStack::DualStack((_core_ctx, converter)) => {
match converter.convert(conn) {
EitherStack::ThisStack((conn, _addr)) => {
&mut conn.handshake_status
}
EitherStack::OtherStack((conn, _addr)) => {
&mut conn.handshake_status
}
}
}
};
// A repeated connect reports handshake progress. Success is
// reported exactly once; `reported` records that it happened.
match handshake_status {
HandshakeStatus::Pending => return Err(ConnectError::Pending),
HandshakeStatus::Aborted => return Err(ConnectError::Aborted),
HandshakeStatus::Completed { reported } => {
if *reported {
return Err(ConnectError::Completed);
} else {
*reported = true;
return Ok(());
}
}
}
}
// Unbound: no local address in either stack yet.
TcpSocketStateInner::Unbound(Unbound {
bound_device: _,
socket_extra,
buffer_sizes,
socket_options,
sharing,
}) => (
DualStackTuple::<I, _>::new(None, None),
*sharing,
*socket_options,
*buffer_sizes,
socket_extra.to_ref(),
),
TcpSocketStateInner::Bound(BoundSocketState::Listener((
listener,
ListenerSharingState { sharing, listening: _ },
addr,
))) => {
// Split the bound address into per-stack local addresses. A
// `BothStacks` binding yields a wildcard address in each stack.
let local_addr = match &core_ctx {
MaybeDualStack::DualStack((_core_ctx, converter)) => {
match converter.convert(addr.clone()) {
ListenerAddr {
ip: DualStackListenerIpAddr::ThisStack(ip),
device,
} => DualStackTuple::new(
Some(ListenerAddr { ip, device }),
None,
),
ListenerAddr {
ip: DualStackListenerIpAddr::OtherStack(ip),
device,
} => DualStackTuple::new(
None,
Some(ListenerAddr { ip, device }),
),
ListenerAddr {
ip: DualStackListenerIpAddr::BothStacks(port),
device,
} => DualStackTuple::new(
Some(ListenerAddr {
ip: ListenerIpAddr { addr: None, identifier: port },
device: device.clone(),
}),
Some(ListenerAddr {
ip: ListenerIpAddr { addr: None, identifier: port },
device,
}),
),
}
}
MaybeDualStack::NotDualStack((_core_ctx, converter)) => {
DualStackTuple::new(Some(converter.convert(addr.clone())), None)
}
};
// Only a bound (non-listening) socket may actively connect.
match listener {
MaybeListener::Bound(BoundState {
buffer_sizes,
socket_options,
socket_extra,
}) => (
local_addr,
*sharing,
*socket_options,
*buffer_sizes,
socket_extra.to_ref(),
),
MaybeListener::Listener(_) => return Err(ConnectError::Listener),
}
}
};
// `local_addr` is now (this-stack local address, other-stack local address).
let local_addr = local_addr.into_inner();
match (core_ctx, local_addr, remote_ip) {
// Single-stack context: connect within this stack.
(
MaybeDualStack::NotDualStack((core_ctx, converter)),
(local_addr_this_stack, None),
DualStackRemoteIp::ThisStack(remote_ip),
) => {
*socket_state = connect_inner(
core_ctx,
bindings_ctx,
id,
isn,
local_addr_this_stack.clone(),
remote_ip,
remote_port,
socket_extra,
buffer_sizes,
socket_options,
sharing,
SingleStackDemuxStateAccessor(
&I::into_demux_socket_id(id.clone()),
local_addr_this_stack,
),
|conn, addr| converter.convert_back((conn, addr)),
<C::CoreContext as CoreTimerContext<_, _>>::convert_timer,
)?;
Ok(())
}
// Dual-stack context, this-stack remote. Allowed when not bound
// in the other stack, or when bound in both stacks.
(
MaybeDualStack::DualStack((core_ctx, converter)),
(local_addr_this_stack, local_addr_other_stack @ None)
| (local_addr_this_stack @ Some(_), local_addr_other_stack @ Some(_)),
DualStackRemoteIp::ThisStack(remote_ip),
) => {
*socket_state = connect_inner(
core_ctx,
bindings_ctx,
id,
isn,
local_addr_this_stack.clone(),
remote_ip,
remote_port,
socket_extra,
buffer_sizes,
socket_options,
sharing,
DualStackDemuxStateAccessor(
id,
DualStackTuple::new(local_addr_this_stack, local_addr_other_stack),
),
|conn, addr| {
converter.convert_back(EitherStack::ThisStack((conn, addr)))
},
<C::CoreContext as CoreTimerContext<_, _>>::convert_timer,
)?;
Ok(())
}
// Dual-stack context, other-stack remote. Allowed when not bound
// in this stack, or when bound in both stacks, provided dual-stack
// is enabled on the socket.
(
MaybeDualStack::DualStack((core_ctx, converter)),
(local_addr_this_stack @ None, local_addr_other_stack)
| (local_addr_this_stack @ Some(_), local_addr_other_stack @ Some(_)),
DualStackRemoteIp::OtherStack(remote_ip),
) => {
if !core_ctx.dual_stack_enabled(ip_options) {
return Err(ConnectError::NoRoute);
}
*socket_state = connect_inner(
core_ctx,
bindings_ctx,
id,
isn,
local_addr_other_stack.clone(),
remote_ip,
remote_port,
socket_extra,
buffer_sizes,
socket_options,
sharing,
DualStackDemuxStateAccessor(
id,
DualStackTuple::new(local_addr_this_stack, local_addr_other_stack),
),
|conn, addr| {
converter.convert_back(EitherStack::OtherStack((conn, addr)))
},
<C::CoreContext as CoreTimerContext<_, _>>::convert_timer,
)?;
Ok(())
}
// The remaining combinations pair a local address in one stack
// with a remote address in the other, which cannot be routed.
(
MaybeDualStack::NotDualStack(_),
(_, Some(_other_stack_local_addr)),
DualStackRemoteIp::ThisStack(_) | DualStackRemoteIp::OtherStack(_),
) => unreachable!("The socket cannot be bound in the other stack"),
(
MaybeDualStack::DualStack(_),
(_, Some(_other_stack_local_addr)),
DualStackRemoteIp::ThisStack(_),
) => Err(ConnectError::NoRoute),
(
MaybeDualStack::DualStack(_) | MaybeDualStack::NotDualStack(_),
(Some(_this_stack_local_addr), _),
DualStackRemoteIp::OtherStack(_),
) => Err(ConnectError::NoRoute),
(
MaybeDualStack::NotDualStack(_),
(None, None),
DualStackRemoteIp::OtherStack(_),
) => Err(ConnectError::NoRoute),
}
});
// Record failure counters for any unsuccessful attempt.
match &result {
Ok(()) => {}
Err(err) => {
core_ctx.increment(|counters| &counters.failed_connection_attempts);
match err {
ConnectError::NoRoute => {
core_ctx.increment(|counters| &counters.active_open_no_route_errors)
}
ConnectError::NoPort => {
core_ctx.increment(|counters| &counters.failed_port_reservations)
}
_ => {}
}
}
}
result
}
/// Closes the socket `id`, consuming the caller's handle.
///
/// - Unbound sockets are destroyed immediately.
/// - Listener and bound sockets are first removed from every demux socketmap
///   they are registered in, then destroyed. For a listening socket, the
///   accept queue is closed and any pending connections from it are closed
///   via `close_pending_sockets`.
/// - Connected sockets are marked defunct, their receive side is shut down
///   and a close is initiated. If the close is not immediately final, data
///   may be sent via `do_send_inner`. The socket is destroyed here only if
///   the connection has already reached `State::Closed`. Otherwise
///   destruction is deferred; `handle_timer` destroys sockets that are both
///   defunct and closed.
pub fn close(&mut self, id: TcpApiSocketId<I, C>) {
debug!("close on {id:?}");
let (core_ctx, bindings_ctx) = self.contexts();
// `destroy`: whether the socket can be destroyed now.
// `pending`: connections still queued on a listener's accept queue.
let (destroy, pending) =
core_ctx.with_socket_mut_transport_demux(&id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
match socket_state {
TcpSocketStateInner::Unbound(_) => (true, None),
TcpSocketStateInner::Bound(BoundSocketState::Listener((
maybe_listener,
_sharing,
addr,
))) => {
// Remove the listener entry from the demux socketmap(s) for
// whichever stack(s) it is bound in.
match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
TcpDemuxContext::<I, _, _>::with_demux_mut(
core_ctx,
|DemuxState { socketmap }| {
socketmap
.listeners_mut()
.remove(
&I::into_demux_socket_id(id.clone()),
&converter.convert(addr),
)
.expect("failed to remove from socketmap");
},
);
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(addr.clone()) {
ListenerAddr {
ip: DualStackListenerIpAddr::ThisStack(ip),
device,
} => TcpDemuxContext::<I, _, _>::with_demux_mut(
core_ctx,
|DemuxState { socketmap }| {
socketmap
.listeners_mut()
.remove(
&I::into_demux_socket_id(id.clone()),
&ListenerAddr { ip, device },
)
.expect("failed to remove from socketmap");
},
),
ListenerAddr {
ip: DualStackListenerIpAddr::OtherStack(ip),
device,
} => {
let other_demux_id =
core_ctx.into_other_demux_socket_id(id.clone());
TcpDemuxContext::<I::OtherVersion, _, _>::with_demux_mut(
core_ctx,
|DemuxState { socketmap }| {
socketmap
.listeners_mut()
.remove(
&other_demux_id,
&ListenerAddr { ip, device },
)
.expect("failed to remove from socketmap");
},
);
}
// Bound in both stacks: remove the wildcard entry from
// both demuxes.
ListenerAddr {
ip: DualStackListenerIpAddr::BothStacks(port),
device,
} => {
let other_demux_id =
core_ctx.into_other_demux_socket_id(id.clone());
core_ctx.with_both_demux_mut(|demux, other_demux| {
demux
.socketmap
.listeners_mut()
.remove(
&I::into_demux_socket_id(id.clone()),
&ListenerAddr {
ip: ListenerIpAddr {
addr: None,
identifier: port,
},
device: device.clone(),
},
)
.expect("failed to remove from socketmap");
other_demux
.socketmap
.listeners_mut()
.remove(
&other_demux_id,
&ListenerAddr {
ip: ListenerIpAddr {
addr: None,
identifier: port,
},
device,
},
)
.expect("failed to remove from socketmap");
});
}
}
}
};
// Demote a listener to a plain bound state and take ownership of
// its queued connections, so they can be closed after the lock
// is released.
let pending =
replace_with::replace_with_and(maybe_listener, |maybe_listener| {
match maybe_listener {
MaybeListener::Bound(b) => (MaybeListener::Bound(b), None),
MaybeListener::Listener(listener) => {
let Listener {
backlog: _,
accept_queue,
buffer_sizes,
socket_options,
} = listener;
let (pending, socket_extra) = accept_queue.close();
let bound_state = BoundState {
buffer_sizes,
socket_options,
socket_extra: Takeable::new(socket_extra),
};
(MaybeListener::Bound(bound_state), Some(pending))
}
}
});
(true, pending)
}
TcpSocketStateInner::Bound(BoundSocketState::Connected {
conn,
sharing: _,
timer,
}) => {
// Initiates the close of one connection. The function is generic
// over the wire IP version, so the same logic serves this-stack
// and other-stack connections. Returns whether the connection is
// now fully closed.
fn do_close<SockI, WireI, CC, BC>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
id: &TcpSocketId<SockI, CC::WeakDeviceId, BC>,
demux_id: &WireI::DemuxSocketId<CC::WeakDeviceId, BC>,
conn: &mut Connection<SockI, WireI, CC::WeakDeviceId, BC>,
addr: &ConnAddr<
ConnIpAddr<<WireI as Ip>::Addr, NonZeroU16, NonZeroU16>,
CC::WeakDeviceId,
>,
timer: &mut BC::Timer,
) -> bool
where
SockI: DualStackIpExt,
WireI: DualStackIpExt,
BC: TcpBindingsContext,
CC: TransportIpContext<WireI, BC>
+ TcpDemuxContext<WireI, CC::WeakDeviceId, BC>
+ CounterContext<TcpCounters<SockI>>,
{
// Receive-side shutdown errors are intentionally ignored here.
let _: Result<(), CloseError> = conn.state.shutdown_recv();
// The user has released the socket. `handle_timer` destroys
// defunct sockets once they reach `Closed`.
conn.defunct = true;
let newly_closed = match core_ctx.with_counters(|counters| {
conn.state.close(
counters,
CloseReason::Close { now: bindings_ctx.now() },
&conn.socket_options,
)
}) {
Err(CloseError::NoConnection) => NewlyClosed::No,
Err(CloseError::Closing) | Ok(NewlyClosed::No) => {
do_send_inner(&id, conn, &addr, timer, core_ctx, bindings_ctx)
}
Ok(NewlyClosed::Yes) => NewlyClosed::Yes,
};
handle_newly_closed(
core_ctx,
bindings_ctx,
newly_closed,
demux_id,
addr,
timer,
);
let now_closed = matches!(conn.state, State::Closed(_));
// A closed connection must no longer be in the socketmap or
// have a timer scheduled.
if now_closed {
debug_assert!(
core_ctx.with_demux_mut(|DemuxState { socketmap }| {
socketmap.conns_mut().entry(demux_id, addr).is_none()
}),
"lingering state in socketmap: demux_id: {:?}, addr: {:?}",
demux_id,
addr,
);
debug_assert_eq!(
bindings_ctx.scheduled_instant(timer),
None,
"lingering timer for {:?}",
id,
)
};
now_closed
}
let closed = match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn);
do_close(
core_ctx,
bindings_ctx,
&id,
&I::into_demux_socket_id(id.clone()),
conn,
addr,
timer,
)
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(conn) {
EitherStack::ThisStack((conn, addr)) => do_close(
core_ctx,
bindings_ctx,
&id,
&I::into_demux_socket_id(id.clone()),
conn,
addr,
timer,
),
EitherStack::OtherStack((conn, addr)) => do_close(
core_ctx,
bindings_ctx,
&id,
&core_ctx.into_other_demux_socket_id(id.clone()),
conn,
addr,
timer,
),
}
}
};
(closed, None)
}
}
});
// These calls run after the socket-state closure above has returned.
close_pending_sockets(core_ctx, bindings_ctx, pending.into_iter().flatten());
if destroy {
destroy_socket(core_ctx, bindings_ctx, id);
}
}
/// Shuts down the send and/or receive side of `id` according to
/// `shutdown_type`.
///
/// - Unbound sockets: `Err(NoConnection)`.
/// - Connected sockets: a receive shutdown calls `shutdown_recv`. A send
///   shutdown closes the connection, with `CloseReason::Shutdown`, and may
///   send data via `do_send_inner`. Returns `Ok(true)` on success.
/// - Listener / bound sockets: only the receive direction has an effect.
///   - Without a receive shutdown, returns `Ok(false)` and changes nothing.
///   - A bound but non-listening socket returns `Err(NoConnection)`.
///   - A listening socket is demoted to a non-listening bound socket: its
///     socketmap sharing state is updated, its accept queue is closed and
///     the queued connections are closed. Returns `Ok(false)`.
///
/// NOTE(review): the returned `bool` is `true` only for connected sockets.
/// Its exact meaning to callers should be confirmed.
pub fn shutdown(
&mut self,
id: &TcpApiSocketId<I, C>,
shutdown_type: ShutdownType,
) -> Result<bool, NoConnection> {
debug!("shutdown [{shutdown_type:?}] for {id:?}");
let (core_ctx, bindings_ctx) = self.contexts();
let (result, pending) =
core_ctx.with_socket_mut_transport_demux(id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
match socket_state {
TcpSocketStateInner::Unbound(_) => Err(NoConnection),
TcpSocketStateInner::Bound(BoundSocketState::Connected {
conn,
sharing: _,
timer,
}) => {
// Applies the shutdown to one connection. The function is
// generic over the wire IP version so it serves both stacks.
fn do_shutdown<SockI, WireI, CC, BC>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
id: &TcpSocketId<SockI, CC::WeakDeviceId, BC>,
demux_id: &WireI::DemuxSocketId<CC::WeakDeviceId, BC>,
conn: &mut Connection<SockI, WireI, CC::WeakDeviceId, BC>,
addr: &ConnAddr<
ConnIpAddr<<WireI as Ip>::Addr, NonZeroU16, NonZeroU16>,
CC::WeakDeviceId,
>,
timer: &mut BC::Timer,
shutdown_type: ShutdownType,
) -> Result<(), NoConnection>
where
SockI: DualStackIpExt,
WireI: DualStackIpExt,
BC: TcpBindingsContext,
CC: TransportIpContext<WireI, BC>
+ TcpDemuxContext<WireI, CC::WeakDeviceId, BC>
+ CounterContext<TcpCounters<SockI>>,
{
let (shutdown_send, shutdown_receive) = shutdown_type.to_send_receive();
if shutdown_receive {
match conn.state.shutdown_recv() {
Ok(()) => (),
Err(CloseError::NoConnection) => return Err(NoConnection),
Err(CloseError::Closing) => (),
}
}
if !shutdown_send {
return Ok(());
}
match core_ctx.with_counters(|counters| {
conn.state.close(
counters,
CloseReason::Shutdown,
&conn.socket_options,
)
}) {
Ok(newly_closed) => {
// If the close did not finish the connection,
// run the send path (which may itself close it).
let newly_closed = match newly_closed {
NewlyClosed::Yes => NewlyClosed::Yes,
NewlyClosed::No => do_send_inner(
id,
conn,
addr,
timer,
core_ctx,
bindings_ctx,
),
};
handle_newly_closed(
core_ctx,
bindings_ctx,
newly_closed,
demux_id,
addr,
timer,
);
Ok(())
}
Err(CloseError::NoConnection) => Err(NoConnection),
// Already closing: a repeated send shutdown is a no-op.
Err(CloseError::Closing) => Ok(()),
}
}
match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn);
do_shutdown(
core_ctx,
bindings_ctx,
id,
&I::into_demux_socket_id(id.clone()),
conn,
addr,
timer,
shutdown_type,
)?
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(conn) {
EitherStack::ThisStack((conn, addr)) => do_shutdown(
core_ctx,
bindings_ctx,
id,
&I::into_demux_socket_id(id.clone()),
conn,
addr,
timer,
shutdown_type,
)?,
EitherStack::OtherStack((conn, addr)) => do_shutdown(
core_ctx,
bindings_ctx,
id,
&core_ctx.into_other_demux_socket_id(id.clone()),
conn,
addr,
timer,
shutdown_type,
)?,
}
}
};
Ok((true, None))
}
TcpSocketStateInner::Bound(BoundSocketState::Listener((
maybe_listener,
sharing,
addr,
))) => {
// For listeners only the receive direction matters.
let (_shutdown_send, shutdown_receive) = shutdown_type.to_send_receive();
if !shutdown_receive {
return Ok((false, None));
}
match maybe_listener {
MaybeListener::Bound(_) => return Err(NoConnection),
MaybeListener::Listener(_) => {}
}
// Mark the socketmap entry as no longer listening.
let new_sharing = {
let ListenerSharingState { sharing, listening } = sharing;
assert!(*listening, "listener {id:?} is not listening");
ListenerSharingState { listening: false, sharing: sharing.clone() }
};
*sharing = try_update_listener_sharing::<_, C::CoreContext, _>(
core_ctx,
id,
addr.clone(),
sharing,
new_sharing,
)
.unwrap_or_else(|e| {
unreachable!(
"downgrading a TCP listener to bound should not fail, got {e:?}"
)
});
// Demote to bound state and take the queued connections so they
// can be closed after the lock is released.
let queued_items =
replace_with::replace_with_and(maybe_listener, |maybe_listener| {
let Listener {
backlog: _,
accept_queue,
buffer_sizes,
socket_options,
} = assert_matches!(maybe_listener,
MaybeListener::Listener(l) => l, "must be a listener");
let (pending, socket_extra) = accept_queue.close();
let bound_state = BoundState {
buffer_sizes,
socket_options,
socket_extra: Takeable::new(socket_extra),
};
(MaybeListener::Bound(bound_state), pending)
});
Ok((false, Some(queued_items)))
}
}
})?;
close_pending_sockets(core_ctx, bindings_ctx, pending.into_iter().flatten());
Ok(result)
}
/// Notifies the socket `id` that the application has dequeued data from its
/// receive buffer.
///
/// For connected sockets, the state machine is asked via
/// `poll_receive_data_dequeued` whether an ACK should go out. NOTE(review):
/// presumably this is a window update; confirm. If so, the ACK is sent as an
/// empty segment using the connection's IP socket and IP options. Unbound
/// and listener sockets are ignored.
pub fn on_receive_buffer_read(&mut self, id: &TcpApiSocketId<I, C>) {
let (core_ctx, bindings_ctx) = self.contexts();
core_ctx.with_socket_mut_transport_demux(
id,
|core_ctx, TcpSocketState { socket_state, ip_options: _ }| {
// Only connected sockets have a receive window to advertise.
let conn = match socket_state {
TcpSocketStateInner::Unbound(_) => return,
TcpSocketStateInner::Bound(bound) => match bound {
BoundSocketState::Listener(_) => return,
BoundSocketState::Connected { conn, sharing: _, timer: _ } => conn,
},
};
// The branches below are identical except for the concrete
// (per-stack) context/connection types, which are distinct.
match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn);
if let Some(ack) = conn.state.poll_receive_data_dequeued() {
send_tcp_segment(
core_ctx,
bindings_ctx,
Some(id),
Some(&conn.ip_sock),
addr.ip,
ack.into_empty(),
&conn.socket_options.ip_options,
)
}
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(conn) {
EitherStack::ThisStack((conn, addr)) => {
if let Some(ack) = conn.state.poll_receive_data_dequeued() {
send_tcp_segment(
core_ctx,
bindings_ctx,
Some(id),
Some(&conn.ip_sock),
addr.ip,
ack.into_empty(),
&conn.socket_options.ip_options,
)
}
}
EitherStack::OtherStack((conn, addr)) => {
if let Some(ack) = conn.state.poll_receive_data_dequeued() {
send_tcp_segment(
core_ctx,
bindings_ctx,
Some(id),
Some(&conn.ip_sock),
addr.ip,
ack.into_empty(),
&conn.socket_options.ip_options,
)
}
}
}
}
}
},
)
}
/// Moves a single connection onto `new_device`, or removes its device
/// binding when `new_device` is `None`.
///
/// The steps are:
/// 1. Verify that the connection's local/remote addresses permit the device
///    change.
/// 2. Build a new IP socket routed through `new_device`.
/// 3. Re-key the connection's entry in the demux socketmap.
///
/// `addr` and `conn.ip_sock` are updated only if every step succeeds. On
/// failure the connection is left untouched.
///
/// # Errors
///
/// - [`SetDeviceError::ZoneChange`] if the addresses forbid the change.
/// - [`SetDeviceError::Unroutable`] if no IP socket can be created through
///   the device.
/// - [`SetDeviceError::Conflict`] if the re-keyed address is already taken.
///
/// # Panics
///
/// Panics if the connection is not present in the socketmap under
/// `demux_id`/`addr`.
fn set_device_conn<SockI, WireI, CC>(
    core_ctx: &mut CC,
    bindings_ctx: &mut C::BindingsContext,
    addr: &mut ConnAddr<ConnIpAddr<WireI::Addr, NonZeroU16, NonZeroU16>, CC::WeakDeviceId>,
    demux_id: &WireI::DemuxSocketId<CC::WeakDeviceId, C::BindingsContext>,
    conn: &mut Connection<SockI, WireI, CC::WeakDeviceId, C::BindingsContext>,
    new_device: Option<CC::DeviceId>,
) -> Result<(), SetDeviceError>
where
    SockI: DualStackIpExt,
    WireI: DualStackIpExt,
    CC: TransportIpContext<WireI, C::BindingsContext>
        + TcpDemuxContext<WireI, CC::WeakDeviceId, C::BindingsContext>,
{
    let ConnAddr {
        device: old_device,
        ip: ConnIpAddr { local: (local_ip, _), remote: (remote_ip, _) },
    } = addr;
    let update = SocketDeviceUpdate {
        local_ip: Some(local_ip.as_ref()),
        remote_ip: Some(remote_ip.as_ref()),
        old_device: old_device.as_ref(),
    };
    match update.check_update(new_device.as_ref()) {
        Ok(()) => (),
        Err(SocketDeviceUpdateNotAllowedError) => return Err(SetDeviceError::ZoneChange),
    }
    let new_socket = core_ctx
        .new_ip_socket(
            bindings_ctx,
            new_device.as_ref().map(EitherDeviceId::Strong),
            IpDeviceAddr::new_from_socket_ip_addr(*local_ip),
            *remote_ip,
            IpProto::Tcp.into(),
            &conn.socket_options.ip_options,
        )
        .map_err(|_: IpSockCreationError| SetDeviceError::Unroutable)?;
    core_ctx.with_demux_mut(|DemuxState { socketmap }| {
        // This is a *connection* entry (`conns_mut`), so report it as such
        // when it is missing.
        let entry = socketmap
            .conns_mut()
            .entry(demux_id, addr)
            .unwrap_or_else(|| panic!("invalid conn ID {:?}", demux_id));
        match entry
            .try_update_addr(ConnAddr { device: new_socket.device().cloned(), ..addr.clone() })
        {
            Ok(entry) => {
                // Commit: record the stored address and swap in the new IP
                // socket only after the socketmap update succeeded.
                *addr = entry.get_addr().clone();
                conn.ip_sock = new_socket;
                Ok(())
            }
            Err((ExistsError, _entry)) => Err(SetDeviceError::Conflict),
        }
    })
}
/// Re-keys a listener's entry in a single demux socketmap from `*old_device`
/// to `new_device`.
///
/// On success `*old_device` is overwritten with the device now stored in the
/// socketmap. On error nothing is modified.
///
/// # Errors
///
/// - [`SetDeviceError::ZoneChange`] if the bound address forbids the change.
/// - [`SetDeviceError::Conflict`] if the new address is already in use.
///
/// # Panics
///
/// Panics if `demux_id` has no listener entry at `ip_addr` on `*old_device`.
fn set_device_listener<WireI, D>(
    demux_id: &WireI::DemuxSocketId<D, C::BindingsContext>,
    ip_addr: ListenerIpAddr<WireI::Addr, NonZeroU16>,
    old_device: &mut Option<D>,
    new_device: Option<&D>,
    DemuxState { socketmap }: &mut DemuxState<WireI, D, C::BindingsContext>,
) -> Result<(), SetDeviceError>
where
    WireI: DualStackIpExt,
    D: WeakDeviceIdentifier,
{
    let current_addr = ListenerAddr { ip: ip_addr, device: old_device.clone() };
    let entry = socketmap.listeners_mut().entry(demux_id, &current_addr).expect("invalid ID");
    SocketDeviceUpdate {
        local_ip: ip_addr.addr.as_ref().map(|a| a.as_ref()),
        remote_ip: None,
        old_device: old_device.as_ref(),
    }
    .check_update(new_device)
    .map_err(|SocketDeviceUpdateNotAllowedError| SetDeviceError::ZoneChange)?;
    let updated = entry
        .try_update_addr(ListenerAddr { device: new_device.cloned(), ip: ip_addr })
        .map_err(|(ExistsError, _entry)| SetDeviceError::Conflict)?;
    *old_device = updated.get_addr().device.clone();
    Ok(())
}
/// Binds the socket `id` to `new_device`, or removes any device binding when
/// `new_device` is `None`.
///
/// - Unbound sockets only record the (weak) device for later use.
/// - Connected sockets are moved via `set_device_conn`. That builds a new IP
///   socket through the device and re-keys the connection in the socketmap.
/// - Listener / bound sockets have their socketmap entry re-keyed in each
///   demux they are registered in. A socket bound in both stacks is updated
///   in both demuxes. If the second update fails, the first is rolled back so
///   both demuxes keep agreeing with the socket's stored device.
///
/// # Errors
///
/// - [`SetDeviceError::ZoneChange`] if the socket's addresses do not permit
///   the change.
/// - [`SetDeviceError::Unroutable`] if a connected socket cannot be routed
///   through the new device.
/// - [`SetDeviceError::Conflict`] if the resulting address is already in use.
pub fn set_device(
    &mut self,
    id: &TcpApiSocketId<I, C>,
    new_device: Option<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
) -> Result<(), SetDeviceError> {
    let (core_ctx, bindings_ctx) = self.contexts();
    let weak_device = new_device.as_ref().map(|d| d.downgrade());
    core_ctx.with_socket_mut_transport_demux(id, move |core_ctx, socket_state| {
        debug!("set device on {id:?} to {new_device:?}");
        let TcpSocketState { socket_state, ip_options: _ } = socket_state;
        match socket_state {
            TcpSocketStateInner::Unbound(unbound) => {
                unbound.bound_device = weak_device;
                Ok(())
            }
            TcpSocketStateInner::Bound(BoundSocketState::Connected {
                conn: conn_and_addr,
                sharing: _,
                timer: _,
            }) => {
                // Resolve which stack the connection lives on, together with
                // the matching core context and demux ID.
                let this_or_other_stack = match core_ctx {
                    MaybeDualStack::NotDualStack((core_ctx, converter)) => {
                        let (conn, addr) = converter.convert(conn_and_addr);
                        EitherStack::ThisStack((
                            core_ctx.as_this_stack(),
                            conn,
                            addr,
                            I::into_demux_socket_id(id.clone()),
                        ))
                    }
                    MaybeDualStack::DualStack((core_ctx, converter)) => {
                        match converter.convert(conn_and_addr) {
                            EitherStack::ThisStack((conn, addr)) => EitherStack::ThisStack((
                                core_ctx.as_this_stack(),
                                conn,
                                addr,
                                I::into_demux_socket_id(id.clone()),
                            )),
                            EitherStack::OtherStack((conn, addr)) => {
                                let demux_id = core_ctx.into_other_demux_socket_id(id.clone());
                                EitherStack::OtherStack((core_ctx, conn, addr, demux_id))
                            }
                        }
                    }
                };
                match this_or_other_stack {
                    EitherStack::ThisStack((core_ctx, conn, addr, demux_id)) => {
                        Self::set_device_conn::<_, I, _>(
                            core_ctx,
                            bindings_ctx,
                            addr,
                            &demux_id,
                            conn,
                            new_device,
                        )
                    }
                    EitherStack::OtherStack((core_ctx, conn, addr, demux_id)) => {
                        Self::set_device_conn::<_, I::OtherVersion, _>(
                            core_ctx,
                            bindings_ctx,
                            addr,
                            &demux_id,
                            conn,
                            new_device,
                        )
                    }
                }
            }
            TcpSocketStateInner::Bound(BoundSocketState::Listener((
                _listener,
                _sharing,
                addr,
            ))) => match core_ctx {
                MaybeDualStack::NotDualStack((core_ctx, converter)) => {
                    let ListenerAddr { ip, device } = converter.convert(addr);
                    core_ctx.with_demux_mut(|demux| {
                        Self::set_device_listener(
                            &I::into_demux_socket_id(id.clone()),
                            ip.clone(),
                            device,
                            weak_device.as_ref(),
                            demux,
                        )
                    })
                }
                MaybeDualStack::DualStack((core_ctx, converter)) => {
                    match converter.convert(addr) {
                        ListenerAddr { ip: DualStackListenerIpAddr::ThisStack(ip), device } => {
                            TcpDemuxContext::<I, _, _>::with_demux_mut(core_ctx, |demux| {
                                Self::set_device_listener(
                                    &I::into_demux_socket_id(id.clone()),
                                    ip.clone(),
                                    device,
                                    weak_device.as_ref(),
                                    demux,
                                )
                            })
                        }
                        ListenerAddr {
                            ip: DualStackListenerIpAddr::OtherStack(ip),
                            device,
                        } => {
                            let other_demux_id =
                                core_ctx.into_other_demux_socket_id(id.clone());
                            TcpDemuxContext::<I::OtherVersion, _, _>::with_demux_mut(
                                core_ctx,
                                |demux| {
                                    Self::set_device_listener(
                                        &other_demux_id,
                                        ip.clone(),
                                        device,
                                        weak_device.as_ref(),
                                        demux,
                                    )
                                },
                            )
                        }
                        ListenerAddr {
                            ip: DualStackListenerIpAddr::BothStacks(port),
                            device,
                        } => {
                            let other_demux_id =
                                core_ctx.into_other_demux_socket_id(id.clone());
                            core_ctx.with_both_demux_mut(|demux, other_demux| {
                                // Both demuxes currently key this listener under
                                // the same (original) device. The first update
                                // below overwrites `*device`, so snapshot the
                                // original first. It is needed to find the
                                // other-stack entry and to roll back on failure.
                                let original_device = device.clone();
                                Self::set_device_listener(
                                    &I::into_demux_socket_id(id.clone()),
                                    ListenerIpAddr { addr: None, identifier: *port },
                                    device,
                                    weak_device.as_ref(),
                                    demux,
                                )?;
                                // `*device` now holds the new device. Look up the
                                // other-stack entry under the original device,
                                // which is where it still lives.
                                let mut other_stack_device = original_device.clone();
                                match Self::set_device_listener(
                                    &other_demux_id,
                                    ListenerIpAddr { addr: None, identifier: *port },
                                    &mut other_stack_device,
                                    weak_device.as_ref(),
                                    other_demux,
                                ) {
                                    Ok(()) => Ok(()),
                                    Err(e) => {
                                        // Roll this stack back to the original
                                        // device so both demuxes (and `*device`)
                                        // agree again.
                                        Self::set_device_listener(
                                            &I::into_demux_socket_id(id.clone()),
                                            ListenerIpAddr { addr: None, identifier: *port },
                                            device,
                                            original_device.as_ref(),
                                            demux,
                                        )
                                        .expect("failed to revert back the device setting");
                                        Err(e)
                                    }
                                }
                            })
                        }
                    }
                }
            },
        }
    })
}
/// Returns a snapshot of the addressing information for socket `id`.
///
/// - Unbound sockets report [`SocketInfo::Unbound`].
/// - Connected sockets report [`SocketInfo::Connection`].
/// - Listener / bound sockets report [`SocketInfo::Bound`].
pub fn get_info(
    &mut self,
    id: &TcpApiSocketId<I, C>,
) -> SocketInfo<I::Addr, <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
    self.core_ctx().with_socket_and_converter(id, |state, _converter| {
        let TcpSocketState { socket_state, ip_options: _ } = state;
        match socket_state {
            TcpSocketStateInner::Unbound(unbound) => SocketInfo::Unbound(unbound.into()),
            TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
                SocketInfo::Connection(I::get_conn_info(conn))
            }
            TcpSocketStateInner::Bound(BoundSocketState::Listener((_, _, addr))) => {
                SocketInfo::Bound(I::get_bound_info(addr))
            }
        }
    })
}
/// Runs the send path for the connection `conn_id`.
///
/// The work is delegated to `do_send_inner_and_then_handle_newly_closed`,
/// using the demux ID of whichever IP stack the connection lives on.
///
/// # Panics
///
/// Panics if `conn_id` does not refer to a connected socket.
pub fn do_send(&mut self, conn_id: &TcpApiSocketId<I, C>) {
    let (core_ctx, bindings_ctx) = self.contexts();
    core_ctx.with_socket_mut_transport_demux(conn_id, |core_ctx, state| {
        let TcpSocketState { socket_state, ip_options: _ } = state;
        let (conn, timer) = assert_matches!(
            socket_state,
            TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, timer, .. })
                => (conn, timer)
        );
        match core_ctx {
            MaybeDualStack::NotDualStack((core_ctx, converter)) => {
                let (conn, addr) = converter.convert(conn);
                let demux_id = I::into_demux_socket_id(conn_id.clone());
                do_send_inner_and_then_handle_newly_closed(
                    conn_id,
                    demux_id,
                    conn,
                    addr,
                    timer,
                    core_ctx,
                    bindings_ctx,
                )
            }
            MaybeDualStack::DualStack((core_ctx, converter)) => match converter.convert(conn) {
                EitherStack::ThisStack((conn, addr)) => {
                    let demux_id = I::into_demux_socket_id(conn_id.clone());
                    do_send_inner_and_then_handle_newly_closed(
                        conn_id,
                        demux_id,
                        conn,
                        addr,
                        timer,
                        core_ctx,
                        bindings_ctx,
                    )
                }
                EitherStack::OtherStack((conn, addr)) => {
                    let demux_id = core_ctx.into_other_demux_socket_id(conn_id.clone());
                    do_send_inner_and_then_handle_newly_closed(
                        conn_id,
                        demux_id,
                        conn,
                        addr,
                        timer,
                        core_ctx,
                        bindings_ctx,
                    )
                }
            },
        }
    })
}
/// Handles the expiry of a connection's timer.
///
/// Returns immediately if the socket behind `weak_id` no longer exists.
/// Otherwise it runs the send path for the connection. If that newly closes
/// the connection, the following happens:
/// - The connection is removed from the demux socketmap. A failed removal is
///   tolerated (and logged) only if the connection was in `TimeWait` before
///   the send.
/// - The connection's timer is cancelled.
///
/// Finally, the socket is destroyed if it is both defunct (already closed by
/// the user) and in `State::Closed`.
///
/// # Panics
///
/// - Panics if the socket is not a connected socket.
/// - Panics if socketmap removal fails for a connection that was not in
///   `TimeWait`.
fn handle_timer(
&mut self,
weak_id: WeakTcpSocketId<
I,
<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
C::BindingsContext,
>,
) {
let id = match weak_id.upgrade() {
Some(c) => c,
None => return,
};
let (core_ctx, bindings_ctx) = self.contexts();
debug!("handle_timer on {id:?}");
trace_duration!(bindings_ctx, c"tcp::handle_timer");
// The closure below is `move`, so reborrow `id` and `bindings_ctx` to keep
// them usable for `destroy_socket` afterwards.
let id_alias = &id;
let bindings_ctx_alias = &mut *bindings_ctx;
let closed_and_defunct =
core_ctx.with_socket_mut_transport_demux(&id, move |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
let id = id_alias;
let bindings_ctx = bindings_ctx_alias;
let (conn, timer) = assert_matches!(
socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected{ conn, sharing: _, timer}) => (conn, timer)
);
// Per-stack timer handling, generic over the wire IP version.
// Returns whether the socket is now both defunct and closed.
fn do_handle_timer<SockI, WireI, CC, BC>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
id: &TcpSocketId<SockI, CC::WeakDeviceId, BC>,
demux_id: &WireI::DemuxSocketId<CC::WeakDeviceId, BC>,
conn: &mut Connection<SockI, WireI, CC::WeakDeviceId, BC>,
addr: &ConnAddr<
ConnIpAddr<<WireI as Ip>::Addr, NonZeroU16, NonZeroU16>,
CC::WeakDeviceId,
>,
timer: &mut BC::Timer,
) -> bool
where
SockI: DualStackIpExt,
WireI: DualStackIpExt,
BC: TcpBindingsContext,
CC: TransportIpContext<WireI, BC>
+ TcpDemuxContext<WireI, CC::WeakDeviceId, BC>
+ CounterContext<TcpCounters<SockI>>,
{
// Captured before the send so a failed socketmap removal
// can be tolerated for TIME-WAIT connections.
let time_wait = matches!(conn.state, State::TimeWait(_));
let newly_closed = do_send_inner(id, conn, addr, timer, core_ctx, bindings_ctx);
match (newly_closed, time_wait) {
(NewlyClosed::Yes, time_wait) => {
let result = core_ctx.with_demux_mut(|DemuxState { socketmap }| {
socketmap
.conns_mut()
.remove(demux_id, addr)
});
result.unwrap_or_else(|e| {
if time_wait {
debug!(
"raced with timewait removal for {id:?} {addr:?}: {e:?}"
);
} else {
panic!("failed to remove from socketmap: {e:?}");
}
});
let _: Option<_> = bindings_ctx.cancel_timer(timer);
}
(NewlyClosed::No, _) => {},
}
conn.defunct && matches!(conn.state, State::Closed(_))
}
match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn);
do_handle_timer(
core_ctx,
bindings_ctx,
id,
&I::into_demux_socket_id(id.clone()),
conn,
addr,
timer,
)
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(conn) {
EitherStack::ThisStack((conn, addr)) => do_handle_timer(
core_ctx,
bindings_ctx,
id,
&I::into_demux_socket_id(id.clone()),
conn,
addr,
timer,
),
EitherStack::OtherStack((conn, addr)) => do_handle_timer(
core_ctx,
bindings_ctx,
id,
&core_ctx.into_other_demux_socket_id(id.clone()),
conn,
addr,
timer,
),
}
}
}
});
if closed_and_defunct {
destroy_socket(core_ctx, bindings_ctx, id);
}
}
/// Calls `f` with mutable access to the [`SocketOptions`] of socket `id` and
/// returns `f`'s result.
///
/// The options are located wherever the socket's current state stores them.
/// For connected sockets, if `f` changed the options (compared by value
/// before and after), the send path is re-run via
/// `do_send_inner_and_then_handle_newly_closed`.
pub fn with_socket_options_mut<R, F: FnOnce(&mut SocketOptions) -> R>(
&mut self,
id: &TcpApiSocketId<I, C>,
f: F,
) -> R {
let (core_ctx, bindings_ctx) = self.contexts();
core_ctx.with_socket_mut_transport_demux(id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
match socket_state {
TcpSocketStateInner::Unbound(unbound) => f(&mut unbound.socket_options),
TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Bound(bound),
_,
_,
))) => f(&mut bound.socket_options),
TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Listener(listener),
_,
_,
))) => f(&mut listener.socket_options),
// For connections, snapshot the options so a change can trigger the
// send path. The three branches differ only in their per-stack types.
TcpSocketStateInner::Bound(BoundSocketState::Connected {
conn,
sharing: _,
timer,
}) => match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn);
let old = conn.socket_options;
let result = f(&mut conn.socket_options);
if old != conn.socket_options {
do_send_inner_and_then_handle_newly_closed(
id,
I::into_demux_socket_id(id.clone()),
conn,
&*addr,
timer,
core_ctx,
bindings_ctx,
);
}
result
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(conn) {
EitherStack::ThisStack((conn, addr)) => {
let old = conn.socket_options;
let result = f(&mut conn.socket_options);
if old != conn.socket_options {
do_send_inner_and_then_handle_newly_closed(
id,
I::into_demux_socket_id(id.clone()),
conn,
&*addr,
timer,
core_ctx,
bindings_ctx,
);
}
result
}
EitherStack::OtherStack((conn, addr)) => {
let old = conn.socket_options;
let result = f(&mut conn.socket_options);
if old != conn.socket_options {
let other_demux_id =
core_ctx.into_other_demux_socket_id(id.clone());
do_send_inner_and_then_handle_newly_closed(
id,
other_demux_id,
conn,
&*addr,
timer,
core_ctx,
bindings_ctx,
);
}
result
}
}
}
},
}
})
}
/// Calls `f` with shared access to the [`SocketOptions`] of socket `id` and
/// returns `f`'s result.
///
/// The options are read from wherever the socket's current state stores
/// them: unbound state, bound/listener state, or the connection.
pub fn with_socket_options<R, F: FnOnce(&SocketOptions) -> R>(
    &mut self,
    id: &TcpApiSocketId<I, C>,
    f: F,
) -> R {
    self.core_ctx().with_socket_and_converter(id, |state, converter| {
        let TcpSocketState { socket_state, ip_options: _ } = state;
        let options = match socket_state {
            TcpSocketStateInner::Unbound(unbound) => &unbound.socket_options,
            TcpSocketStateInner::Bound(BoundSocketState::Listener((
                MaybeListener::Bound(bound),
                _,
                _,
            ))) => &bound.socket_options,
            TcpSocketStateInner::Bound(BoundSocketState::Listener((
                MaybeListener::Listener(listener),
                _,
                _,
            ))) => &listener.socket_options,
            TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
                match converter {
                    MaybeDualStack::NotDualStack(converter) => {
                        let (connection, _) = converter.convert(conn);
                        &connection.socket_options
                    }
                    MaybeDualStack::DualStack(converter) => match converter.convert(conn) {
                        EitherStack::ThisStack((connection, _)) => &connection.socket_options,
                        EitherStack::OtherStack((connection, _)) => &connection.socket_options,
                    },
                }
            }
        };
        f(options)
    })
}
/// Requests a send buffer of `size` bytes for socket `id`.
///
/// Delegates to `set_buffer_size` with the `SendBufferSize` marker.
pub fn set_send_buffer_size(&mut self, id: &TcpApiSocketId<I, C>, size: usize) {
    let core_ctx = self.core_ctx();
    set_buffer_size::<SendBufferSize, I, _, _>(core_ctx, id, size)
}
/// Returns the send buffer size of socket `id`, as reported by
/// `get_buffer_size` with the `SendBufferSize` marker.
///
/// NOTE(review): when `None` is returned is decided by `get_buffer_size`;
/// confirm the semantics there.
pub fn send_buffer_size(&mut self, id: &TcpApiSocketId<I, C>) -> Option<usize> {
    let core_ctx = self.core_ctx();
    get_buffer_size::<SendBufferSize, I, _, _>(core_ctx, id)
}
/// Requests a receive buffer of `size` bytes for socket `id`.
///
/// Delegates to `set_buffer_size` with the `ReceiveBufferSize` marker.
pub fn set_receive_buffer_size(&mut self, id: &TcpApiSocketId<I, C>, size: usize) {
    let core_ctx = self.core_ctx();
    set_buffer_size::<ReceiveBufferSize, I, _, _>(core_ctx, id, size)
}
/// Returns the receive buffer size of socket `id`, as reported by
/// `get_buffer_size` with the `ReceiveBufferSize` marker.
///
/// NOTE(review): when `None` is returned is decided by `get_buffer_size`;
/// confirm the semantics there.
pub fn receive_buffer_size(&mut self, id: &TcpApiSocketId<I, C>) -> Option<usize> {
    let core_ctx = self.core_ctx();
    get_buffer_size::<ReceiveBufferSize, I, _, _>(core_ctx, id)
}
/// Enables (`reuse == true`) or disables address reuse for socket `id`.
///
/// - Unbound sockets only record the new sharing mode.
/// - Listener / bound sockets update their socketmap sharing state. This is
///   a no-op if the mode is unchanged.
///
/// # Errors
///
/// - [`SetReuseAddrError::AddrInUse`] if the socketmap rejects the new mode.
/// - [`SetReuseAddrError::NotSupported`] for connected sockets.
pub fn set_reuseaddr(
    &mut self,
    id: &TcpApiSocketId<I, C>,
    reuse: bool,
) -> Result<(), SetReuseAddrError> {
    let new_sharing = if reuse { SharingState::ReuseAddress } else { SharingState::Exclusive };
    self.core_ctx().with_socket_mut_transport_demux(id, |core_ctx, state| {
        let TcpSocketState { socket_state, ip_options: _ } = state;
        match socket_state {
            TcpSocketStateInner::Unbound(unbound) => {
                unbound.sharing = new_sharing;
                Ok(())
            }
            TcpSocketStateInner::Bound(BoundSocketState::Connected { .. }) => {
                Err(SetReuseAddrError::NotSupported)
            }
            TcpSocketStateInner::Bound(BoundSocketState::Listener((
                _listener,
                current,
                addr,
            ))) => {
                if new_sharing == current.sharing {
                    return Ok(());
                }
                // Keep the listening flag; only the sharing mode changes.
                let desired =
                    ListenerSharingState { sharing: new_sharing, listening: current.listening };
                *current = try_update_listener_sharing::<_, C::CoreContext, _>(
                    core_ctx,
                    id,
                    addr.clone(),
                    current,
                    desired,
                )
                .map_err(|UpdateSharingError| SetReuseAddrError::AddrInUse)?;
                Ok(())
            }
        }
    })
}
/// Reports whether address reuse (`SharingState::ReuseAddress`) is enabled
/// for socket `id`, whatever state it is in.
pub fn reuseaddr(&mut self, id: &TcpApiSocketId<I, C>) -> bool {
    self.core_ctx().with_socket(id, |TcpSocketState { socket_state, ip_options: _ }| {
        let sharing = match socket_state {
            TcpSocketStateInner::Unbound(Unbound { sharing, .. }) => sharing,
            TcpSocketStateInner::Bound(BoundSocketState::Connected { sharing, .. }) => sharing,
            TcpSocketStateInner::Bound(BoundSocketState::Listener((_, listener_sharing, _))) => {
                &listener_sharing.sharing
            }
        };
        matches!(sharing, SharingState::ReuseAddress)
    })
}
/// Reports whether dual-stack operation is enabled in the IP options of
/// socket `id`.
///
/// # Errors
///
/// Returns [`NotDualStackCapableError`] when the core context is not
/// dual-stack capable.
pub fn dual_stack_enabled(
    &mut self,
    id: &TcpSocketId<
        I,
        <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
        C::BindingsContext,
    >,
) -> Result<bool, NotDualStackCapableError> {
    self.core_ctx().with_socket_mut_transport_demux(id, |core_ctx, state| {
        let TcpSocketState { socket_state: _, ip_options } = state;
        if let MaybeDualStack::DualStack((core_ctx, _converter)) = core_ctx {
            Ok(core_ctx.dual_stack_enabled(ip_options))
        } else {
            Err(NotDualStackCapableError)
        }
    })
}
/// Sets the mark for `domain` on socket `id` to `mark`.
///
/// This goes through `with_socket_options_mut`. On a connected socket, a
/// changed mark therefore re-runs the send path.
pub fn set_mark(&mut self, id: &TcpApiSocketId<I, C>, domain: MarkDomain, mark: Mark) {
    self.with_socket_options_mut(id, |options| {
        let slot = options.ip_options.marks.get_mut(domain);
        *slot = mark;
    })
}
/// Returns the mark currently set for `domain` on socket `id`.
pub fn get_mark(&mut self, id: &TcpApiSocketId<I, C>, domain: MarkDomain) -> Mark {
    self.with_socket_options(id, |options| {
        let marks = &options.ip_options.marks;
        *marks.get(domain)
    })
}
/// Enables or disables dual-stack operation for socket `id`.
///
/// # Errors
///
/// - [`SetDualStackEnabledError::NotCapable`] if the core context is not
///   dual-stack capable.
/// - [`SetDualStackEnabledError::SocketIsBound`] if the socket is already
///   bound. Only unbound sockets may change this setting.
pub fn set_dual_stack_enabled(
    &mut self,
    id: &TcpSocketId<
        I,
        <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
        C::BindingsContext,
    >,
    value: bool,
) -> Result<(), SetDualStackEnabledError> {
    self.core_ctx().with_socket_mut_transport_demux(id, |core_ctx, state| {
        let TcpSocketState { socket_state, ip_options } = state;
        let core_ctx = match core_ctx {
            MaybeDualStack::NotDualStack(_) => return Err(SetDualStackEnabledError::NotCapable),
            MaybeDualStack::DualStack((core_ctx, _converter)) => core_ctx,
        };
        if let TcpSocketStateInner::Bound(_) = socket_state {
            return Err(SetDualStackEnabledError::SocketIsBound);
        }
        Ok(core_ctx.set_dual_stack_enabled(ip_options, value))
    })
}
/// Applies an ICMP error (`error`, reported against sequence number `seq`) to
/// the connected socket `id`.
///
/// The error is handed to the connection's `on_icmp_error`. If the connection
/// is in the `Closed` state afterwards:
/// - `handshake_status` is set to `Aborted` via `update_if_pending`;
/// - `handle_newly_closed` is invoked, which removes the connection from the
///   demux and cancels its timer when it was newly closed;
/// - if the connection has an accept queue, the socket is removed from that
///   queue and then destroyed;
/// - otherwise a `TimedOut` close reason is replaced by the connection's
///   recorded soft error, if any.
///
/// # Panics
///
/// Panics with "invalid socket ID" if `id` is not a connected socket.
fn on_icmp_error_conn(
core_ctx: &mut C::CoreContext,
bindings_ctx: &mut C::BindingsContext,
id: TcpSocketId<
I,
<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
C::BindingsContext,
>,
seq: SeqNum,
error: IcmpErrorCode,
) {
// `destroy` is `true` only when the closed connection was removed from an
// accept queue; destruction happens after the socket-state closure returns
// because `destroy_socket` needs `core_ctx` again.
let destroy = core_ctx.with_socket_mut_transport_demux(&id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
let (conn_and_addr, timer) = assert_matches!(
socket_state,
TcpSocketStateInner::Bound(
BoundSocketState::Connected { conn, sharing: _, timer } ) => (conn, timer),
"invalid socket ID");
// Deliver the error, then gather mutable references to the connection
// fields used below along with the demux context, demux ID and address
// for whichever stack (this one or the other one) carries the connection.
let (
newly_closed,
accept_queue,
state,
soft_error,
handshake_status,
this_or_other_stack,
) = match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn_and_addr);
let newly_closed = conn.on_icmp_error(core_ctx, seq, error);
(
newly_closed,
&mut conn.accept_queue,
&mut conn.state,
&mut conn.soft_error,
&mut conn.handshake_status,
EitherStack::ThisStack((
core_ctx.as_this_stack(),
I::into_demux_socket_id(id.clone()),
addr,
)),
)
}
MaybeDualStack::DualStack((core_ctx, converter)) => {
match converter.convert(conn_and_addr) {
EitherStack::ThisStack((conn, addr)) => {
let newly_closed = conn.on_icmp_error(core_ctx, seq, error);
(
newly_closed,
&mut conn.accept_queue,
&mut conn.state,
&mut conn.soft_error,
&mut conn.handshake_status,
EitherStack::ThisStack((
core_ctx.as_this_stack(),
I::into_demux_socket_id(id.clone()),
addr,
)),
)
}
EitherStack::OtherStack((conn, addr)) => {
let newly_closed = conn.on_icmp_error(core_ctx, seq, error);
let demux_id = core_ctx.into_other_demux_socket_id(id.clone());
(
newly_closed,
&mut conn.accept_queue,
&mut conn.state,
&mut conn.soft_error,
&mut conn.handshake_status,
EitherStack::OtherStack((core_ctx, demux_id, addr)),
)
}
}
}
};
// Cleanup only applies when the error left the connection closed.
if let State::Closed(Closed { reason }) = state {
debug!("handshake_status: {handshake_status:?}");
let _: bool = handshake_status.update_if_pending(HandshakeStatus::Aborted);
// Remove the demux entry (and cancel the timer) on the stack that
// owns the connection; a no-op unless `newly_closed` is `Yes`.
match this_or_other_stack {
EitherStack::ThisStack((core_ctx, demux_id, addr)) => {
handle_newly_closed::<I, _, _, _>(
core_ctx,
bindings_ctx,
newly_closed,
&demux_id,
addr,
timer,
);
}
EitherStack::OtherStack((core_ctx, demux_id, addr)) => {
handle_newly_closed::<I::OtherVersion, _, _, _>(
core_ctx,
bindings_ctx,
newly_closed,
&demux_id,
addr,
timer,
);
}
};
match accept_queue {
// The socket sits in an accept queue: drop it from the queue
// and request destruction.
Some(accept_queue) => {
accept_queue.remove(&id);
return true;
}
// Otherwise surface a more specific soft error in place of a
// generic timeout, when one was recorded.
None => {
if let Some(err) = reason {
if *err == ConnectionError::TimedOut {
*err = soft_error.unwrap_or(ConnectionError::TimedOut);
}
}
}
}
}
false
});
if destroy {
destroy_socket(core_ctx, bindings_ctx, id);
}
}
fn on_icmp_error(
&mut self,
orig_src_ip: SpecifiedAddr<I::Addr>,
orig_dst_ip: SpecifiedAddr<I::Addr>,
orig_src_port: NonZeroU16,
orig_dst_port: NonZeroU16,
seq: SeqNum,
error: IcmpErrorCode,
) where
C::CoreContext: TcpContext<I::OtherVersion, C::BindingsContext>
+ CounterContext<TcpCounters<I::OtherVersion>>,
C::BindingsContext: TcpBindingsContext,
{
let (core_ctx, bindings_ctx) = self.contexts();
let orig_src_ip = match SocketIpAddr::try_from(orig_src_ip) {
Ok(ip) => ip,
Err(AddrIsMappedError {}) => {
trace!("ignoring ICMP error from IPv4-mapped-IPv6 source: {}", orig_src_ip);
return;
}
};
let orig_dst_ip = match SocketIpAddr::try_from(orig_dst_ip) {
Ok(ip) => ip,
Err(AddrIsMappedError {}) => {
trace!("ignoring ICMP error to IPv4-mapped-IPv6 destination: {}", orig_dst_ip);
return;
}
};
let id = TcpDemuxContext::<I, _, _>::with_demux(core_ctx, |DemuxState { socketmap }| {
socketmap
.conns()
.get_by_addr(&ConnAddr {
ip: ConnIpAddr {
local: (orig_src_ip, orig_src_port),
remote: (orig_dst_ip, orig_dst_port),
},
device: None,
})
.map(|ConnAddrState { sharing: _, id }| id.clone())
});
let id = match id {
Some(id) => id,
None => return,
};
match I::into_dual_stack_ip_socket(id) {
EitherStack::ThisStack(id) => {
Self::on_icmp_error_conn(core_ctx, bindings_ctx, id, seq, error)
}
EitherStack::OtherStack(id) => TcpApi::<I::OtherVersion, C>::on_icmp_error_conn(
core_ctx,
bindings_ctx,
id,
seq,
error,
),
};
}
/// Returns the pending error on socket `id`, if any.
///
/// Only connected sockets can report errors; unbound and listening sockets
/// return `None`. A closed connection with a close reason reports that hard
/// error and leaves any soft error in place. Otherwise the soft error is
/// returned and cleared.
pub fn get_socket_error(&mut self, id: &TcpApiSocketId<I, C>) -> Option<ConnectionError> {
    self.core_ctx().with_socket_mut_and_converter(id, |state, converter| {
        let TcpSocketState { socket_state, ip_options: _ } = state;
        let conn = match socket_state {
            TcpSocketStateInner::Bound(BoundSocketState::Connected {
                conn,
                sharing: _,
                timer: _,
            }) => conn,
            TcpSocketStateInner::Unbound(_)
            | TcpSocketStateInner::Bound(BoundSocketState::Listener(_)) => return None,
        };
        let (conn_state, soft_error) = match converter {
            MaybeDualStack::NotDualStack(converter) => {
                let (conn, _addr) = converter.convert(conn);
                (&conn.state, &mut conn.soft_error)
            }
            MaybeDualStack::DualStack(converter) => match converter.convert(conn) {
                EitherStack::ThisStack((conn, _addr)) => (&conn.state, &mut conn.soft_error),
                EitherStack::OtherStack((conn, _addr)) => (&conn.state, &mut conn.soft_error),
            },
        };
        match conn_state {
            State::Closed(Closed { reason: Some(hard_error) }) => Some(hard_error.clone()),
            _ => soft_error.take(),
        }
    })
}
/// Returns the original destination (address, port) of the connected socket
/// `id`.
///
/// The lookup is performed by `core_ctx.get_original_destination` on a TCP
/// `Tuple` built from the connection's local (source) and remote
/// (destination) endpoints. For dual-stack contexts the resulting address is
/// converted back through the converter and `I::get_original_dst`.
/// NOTE(review): presumably this reflects destination rewriting done by the
/// packet-filtering layer; confirm against `get_original_destination`.
///
/// # Errors
///
/// - `NotConnected` if the socket is unbound or a listener.
/// - `NotFound` if no original destination is known for the tuple.
/// - `UnspecifiedDestinationAddr` / `UnspecifiedDestinationPort` if the
///   returned address is unspecified or the port is zero; both cases are also
///   logged at error level.
pub fn get_original_destination(
&mut self,
id: &TcpApiSocketId<I, C>,
) -> Result<(SpecifiedAddr<I::Addr>, NonZeroU16), OriginalDestinationError> {
self.core_ctx().with_socket_mut_transport_demux(id, |core_ctx, state| {
let TcpSocketState { socket_state, .. } = state;
let conn = match socket_state {
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => conn,
TcpSocketStateInner::Bound(BoundSocketState::Listener(_))
| TcpSocketStateInner::Unbound(_) => {
return Err(OriginalDestinationError::NotConnected)
}
};
// Builds the TCP lookup tuple: src = local endpoint, dst = remote endpoint.
fn tuple<I: IpExt>(
ConnIpAddr { local, remote }: ConnIpAddr<I::Addr, NonZeroU16, NonZeroU16>,
) -> Tuple<I> {
let (local_addr, local_port) = local;
let (remote_addr, remote_port) = remote;
Tuple {
protocol: IpProto::Tcp.into(),
src_addr: local_addr.addr(),
dst_addr: remote_addr.addr(),
src_port_or_id: local_port.get(),
dst_port_or_id: remote_port.get(),
}
}
// The lookup happens on the stack the connection actually uses; for
// the dual-stack case the address is mapped back into `I`.
let (addr, port) = match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (_conn, addr) = converter.convert(conn);
let tuple: Tuple<I> = tuple(addr.ip);
core_ctx
.get_original_destination(&tuple)
.ok_or(OriginalDestinationError::NotFound)
}
MaybeDualStack::DualStack((core_ctx, converter)) => match converter.convert(conn) {
EitherStack::ThisStack((_conn, addr)) => {
let tuple: Tuple<I> = tuple(addr.ip);
let (addr, port) = core_ctx
.get_original_destination(&tuple)
.ok_or(OriginalDestinationError::NotFound)?;
let addr = I::get_original_dst(
converter.convert_back(EitherStack::ThisStack(addr)),
);
Ok((addr, port))
}
EitherStack::OtherStack((_conn, addr)) => {
let tuple: Tuple<I::OtherVersion> = tuple(addr.ip);
let (addr, port) = core_ctx
.get_original_destination(&tuple)
.ok_or(OriginalDestinationError::NotFound)?;
let addr = I::get_original_dst(
converter.convert_back(EitherStack::OtherStack(addr)),
);
Ok((addr, port))
}
},
}?;
// Reject lookups that produced an unspecified address or a zero port.
let addr = SpecifiedAddr::new(addr).ok_or_else(|| {
error!("original destination for socket {id:?} had unspecified addr (port {port})");
OriginalDestinationError::UnspecifiedDestinationAddr
})?;
let port = NonZeroU16::new(port).ok_or_else(|| {
error!("original destination for socket {id:?} had unspecified port (addr {addr})");
OriginalDestinationError::UnspecifiedDestinationPort
})?;
Ok((addr, port))
})
}
/// Records diagnostic information for every TCP socket of IP version `I`
/// into `inspector`, with one child node per socket.
///
/// Each node records "TransportProtocol" and "NetworkProtocol". After that:
/// - unbound sockets record empty local and remote addresses;
/// - listener/bound sockets record their local address (the unspecified
///   address when not bound to one) and, for actual listeners, an
///   "AcceptQueue" child with "BacklogSize" and the queue contents;
/// - connected sockets record local and remote addresses plus "State". A
///   defunct connection records nothing beyond the protocol fields.
pub fn inspect<N>(&mut self, inspector: &mut N)
where
N: Inspector
+ InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
{
self.core_ctx().for_each_socket(|socket_id, socket_state| {
inspector.record_debug_child(socket_id, |node| {
node.record_str("TransportProtocol", "TCP");
node.record_str(
"NetworkProtocol",
match I::VERSION {
IpVersion::V4 => "IPv4",
IpVersion::V6 => "IPv6",
},
);
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
match socket_state {
TcpSocketStateInner::Unbound(_) => {
node.record_local_socket_addr::<N, I::Addr, _, NonZeroU16>(None);
node.record_remote_socket_addr::<N, I::Addr, _, NonZeroU16>(None);
}
TcpSocketStateInner::Bound(BoundSocketState::Listener((
state,
_sharing,
addr,
))) => {
let BoundInfo { addr, port, device } = I::get_bound_info(addr);
// A listener without a specific local address is shown with
// the unspecified address.
let local = addr.map_or_else(
|| ZonedAddr::Unzoned(I::UNSPECIFIED_ADDRESS),
|addr| maybe_zoned(addr.addr(), &device).into(),
);
node.record_local_socket_addr::<N, _, _, _>(Some((local, port)));
node.record_remote_socket_addr::<N, I::Addr, _, NonZeroU16>(None);
match state {
MaybeListener::Bound(_bound_state) => {}
MaybeListener::Listener(Listener { accept_queue, backlog, .. }) => node
.record_child("AcceptQueue", |node| {
node.record_usize("BacklogSize", *backlog);
accept_queue.inspect(node);
}),
};
}
TcpSocketStateInner::Bound(BoundSocketState::Connected {
conn: conn_and_addr,
..
}) => {
// Defunct connections keep only the protocol fields above.
if I::get_defunct(conn_and_addr) {
return;
}
let state = I::get_state(conn_and_addr);
let ConnectionInfo {
local_addr: SocketAddr { ip: local_ip, port: local_port },
remote_addr: SocketAddr { ip: remote_ip, port: remote_port },
device: _,
} = I::get_conn_info(conn_and_addr);
node.record_local_socket_addr::<N, I::Addr, _, _>(Some((
local_ip.into(),
local_port,
)));
node.record_remote_socket_addr::<N, I::Addr, _, _>(Some((
remote_ip.into(),
remote_port,
)));
node.record_display("State", state);
}
}
});
})
}
/// Runs `f` on the send buffer of socket `id`.
///
/// Returns `None` without calling `f` when the socket's current state has no
/// send buffer.
pub fn with_send_buffer<
    R,
    F: FnOnce(&mut <C::BindingsContext as TcpBindingsTypes>::SendBuffer) -> R,
>(
    &mut self,
    id: &TcpApiSocketId<I, C>,
    f: F,
) -> Option<R> {
    self.core_ctx().with_socket_mut_and_converter(id, |state, converter| {
        let buffers = get_buffers_mut::<_, C::CoreContext, _>(state, converter);
        let send_buffer = buffers.into_send_buffer()?;
        Some(f(send_buffer))
    })
}
/// Runs `f` on the receive buffer of socket `id`.
///
/// Returns `None` without calling `f` when the socket's current state has no
/// receive buffer.
pub fn with_receive_buffer<
    R,
    F: FnOnce(&mut <C::BindingsContext as TcpBindingsTypes>::ReceiveBuffer) -> R,
>(
    &mut self,
    id: &TcpApiSocketId<I, C>,
    f: F,
) -> Option<R> {
    self.core_ctx().with_socket_mut_and_converter(id, |state, converter| {
        let buffers = get_buffers_mut::<_, C::CoreContext, _>(state, converter);
        let receive_buffer = buffers.into_receive_buffer()?;
        Some(f(receive_buffer))
    })
}
}
/// Removes socket `id` from the all-sockets set and releases its primary
/// reference.
///
/// Behavior depends on the set entry for `id`:
/// - `Primary`: the entry is removed and the primary reference is unwrapped.
///   If other strong references remain, removal is deferred through
///   `bindings_ctx.defer_removal`.
/// - `DeadOnArrival`: nothing is done.
/// - no entry: if the reference is not yet marked for destruction, a
///   `DeadOnArrival` entry is inserted (destruction raced with insertion);
///   otherwise destruction is already deferred.
///
/// # Panics
///
/// In test builds, panics whenever no primary entry was found (deferred
/// destruction is disallowed in tests).
fn destroy_socket<I: DualStackIpExt, CC: TcpContext<I, BC>, BC: TcpBindingsContext>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
id: TcpSocketId<I, CC::WeakDeviceId, BC>,
) {
// Kept for logging after `id` is moved into the set's entry API.
let weak = id.downgrade();
core_ctx.with_all_sockets_mut(move |all_sockets| {
let TcpSocketId(rc) = &id;
let debug_refs = StrongRc::debug_references(rc);
let entry = all_sockets.entry(id);
let primary = match entry {
hash_map::Entry::Occupied(o) => match o.get() {
TcpSocketSetEntry::DeadOnArrival => {
let id = o.key();
debug!("{id:?} destruction skipped, socket is DOA. References={debug_refs:?}",);
None
}
TcpSocketSetEntry::Primary(_) => {
assert_matches!(o.remove_entry(), (_, TcpSocketSetEntry::Primary(p)) => Some(p))
}
},
hash_map::Entry::Vacant(v) => {
let id = v.key();
let TcpSocketId(rc) = id;
// Not yet in the set: leave a DOA marker so a later insertion
// can tell the socket was already destroyed.
if !StrongRc::marked_for_destruction(rc) {
debug!(
"{id:?} raced with insertion, marking socket as DOA. \
References={debug_refs:?}",
);
let _: &mut _ = v.insert(TcpSocketSetEntry::DeadOnArrival);
} else {
debug!("{id:?} destruction is already deferred. References={debug_refs:?}");
}
None
}
};
#[cfg(test)]
let primary = primary.unwrap_or_else(|| {
panic!("deferred destruction not allowed in tests. References={debug_refs:?}")
});
#[cfg(not(test))]
let Some(primary) = primary
else {
return;
};
// Drop the primary reference now if possible; otherwise hand the
// notifier receiver to bindings to finish removal later.
let remove_result =
BC::unwrap_or_notify_with_new_reference_notifier(primary, |state| state);
match remove_result {
RemoveResourceResult::Removed(state) => debug!("destroyed {weak:?} {state:?}"),
RemoveResourceResult::Deferred(receiver) => {
debug!("deferred removal {weak:?}");
bindings_ctx.defer_removal(receiver)
}
}
})
}
/// Aborts and destroys every connection in `pending`.
///
/// For each connection the retransmission timer is cancelled, then
/// `close_pending_socket` is run on the stack (this or other) that carries
/// it. After that the socket is destroyed with `destroy_socket`.
///
/// # Panics
///
/// Panics with "invalid socket ID" if any `pending` ID is not a connected
/// socket.
fn close_pending_sockets<I, CC, BC>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
pending: impl Iterator<Item = TcpSocketId<I, CC::WeakDeviceId, BC>>,
) where
I: DualStackIpExt,
BC: TcpBindingsContext,
CC: TcpContext<I, BC>,
{
for conn_id in pending {
core_ctx.with_socket_mut_transport_demux(&conn_id, |core_ctx, socket_state| {
let TcpSocketState { socket_state, ip_options: _ } = socket_state;
let (conn_and_addr, timer) = assert_matches!(
socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected{
conn, sharing: _, timer
}) => (conn, timer),
"invalid socket ID"
);
let _: Option<BC::Instant> = bindings_ctx.cancel_timer(timer);
// Pick the demux context/ID and connection for the stack in use.
let this_or_other_stack = match core_ctx {
MaybeDualStack::NotDualStack((core_ctx, converter)) => {
let (conn, addr) = converter.convert(conn_and_addr);
EitherStack::ThisStack((
core_ctx.as_this_stack(),
I::into_demux_socket_id(conn_id.clone()),
conn,
addr.clone(),
))
}
MaybeDualStack::DualStack((core_ctx, converter)) => match converter
.convert(conn_and_addr)
{
EitherStack::ThisStack((conn, addr)) => EitherStack::ThisStack((
core_ctx.as_this_stack(),
I::into_demux_socket_id(conn_id.clone()),
conn,
addr.clone(),
)),
EitherStack::OtherStack((conn, addr)) => {
let other_demux_id = core_ctx.into_other_demux_socket_id(conn_id.clone());
EitherStack::OtherStack((core_ctx, other_demux_id, conn, addr.clone()))
}
},
};
match this_or_other_stack {
EitherStack::ThisStack((core_ctx, demux_id, conn, conn_addr)) => {
close_pending_socket(
core_ctx,
bindings_ctx,
&conn_id,
&demux_id,
timer,
conn,
&conn_addr,
)
}
EitherStack::OtherStack((core_ctx, demux_id, conn, conn_addr)) => {
close_pending_socket(
core_ctx,
bindings_ctx,
&conn_id,
&demux_id,
timer,
conn,
&conn_addr,
)
}
}
});
// Destruction re-borrows `core_ctx`, so it runs outside the closure.
destroy_socket(core_ctx, bindings_ctx, conn_id);
}
}
/// Aborts `conn` (used by `close_pending_sockets`).
///
/// The state machine's `abort` is run under the socket counters. If the
/// connection was newly closed, it is removed from the demux and its timer is
/// cancelled. If the abort produced a reset segment, that reset is sent on
/// the connection's IP socket.
fn close_pending_socket<WireI, SockI, DC, BC>(
    core_ctx: &mut DC,
    bindings_ctx: &mut BC,
    sock_id: &TcpSocketId<SockI, DC::WeakDeviceId, BC>,
    demux_id: &WireI::DemuxSocketId<DC::WeakDeviceId, BC>,
    timer: &mut BC::Timer,
    conn: &mut Connection<SockI, WireI, DC::WeakDeviceId, BC>,
    conn_addr: &ConnAddr<ConnIpAddr<WireI::Addr, NonZeroU16, NonZeroU16>, DC::WeakDeviceId>,
) where
    WireI: DualStackIpExt,
    SockI: DualStackIpExt,
    DC: TransportIpContext<WireI, BC>
        + DeviceIpSocketHandler<WireI, BC>
        + TcpDemuxContext<WireI, DC::WeakDeviceId, BC>
        + CounterContext<TcpCounters<SockI>>,
    BC: TcpBindingsContext,
{
    debug!("aborting pending socket {sock_id:?}");
    let (reset_segment, newly_closed) =
        core_ctx.with_counters(|counters| conn.state.abort(counters));
    handle_newly_closed(core_ctx, bindings_ctx, newly_closed, demux_id, conn_addr, timer);
    let Some(reset_segment) = reset_segment else {
        return;
    };
    let ConnAddr { ip: conn_ip, device: _ } = conn_addr;
    send_tcp_segment(
        core_ctx,
        bindings_ctx,
        Some(sock_id),
        Some(&conn.ip_sock),
        *conn_ip,
        reset_segment.into_empty(),
        &conn.socket_options.ip_options,
    );
}
/// Flushes whatever `conn` can send right now (see `do_send_inner`), then
/// removes the connection from the demux and cancels its timer if the send
/// attempt newly closed it.
fn do_send_inner_and_then_handle_newly_closed<SockI, WireI, CC, BC>(
    conn_id: &TcpSocketId<SockI, CC::WeakDeviceId, BC>,
    demux_id: WireI::DemuxSocketId<CC::WeakDeviceId, BC>,
    conn: &mut Connection<SockI, WireI, CC::WeakDeviceId, BC>,
    addr: &ConnAddr<ConnIpAddr<WireI::Addr, NonZeroU16, NonZeroU16>, CC::WeakDeviceId>,
    timer: &mut BC::Timer,
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
) where
    SockI: DualStackIpExt,
    WireI: DualStackIpExt,
    BC: TcpBindingsContext,
    CC: TransportIpContext<WireI, BC>
        + CounterContext<TcpCounters<SockI>>
        + TcpDemuxContext<WireI, CC::WeakDeviceId, BC>,
{
    let closed_outcome = do_send_inner(conn_id, conn, addr, timer, core_ctx, bindings_ctx);
    handle_newly_closed(core_ctx, bindings_ctx, closed_outcome, &demux_id, addr, timer);
}
/// Cleans up after a connection transitions to closed.
///
/// When `newly_closed` is `NewlyClosed::Yes`, this removes the connection at
/// `addr` from the demux and cancels `timer`, both while holding the demux
/// state. Otherwise it does nothing.
///
/// # Panics
///
/// Panics if the connection is not present in the demux.
#[inline]
fn handle_newly_closed<I, D, CC, BC>(
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
    newly_closed: NewlyClosed,
    demux_id: &I::DemuxSocketId<D, BC>,
    addr: &ConnAddr<ConnIpAddr<I::Addr, NonZeroU16, NonZeroU16>, D>,
    timer: &mut BC::Timer,
) where
    I: DualStackIpExt,
    D: WeakDeviceIdentifier,
    CC: TcpDemuxContext<I, D, BC>,
    BC: TcpBindingsContext,
{
    if newly_closed != NewlyClosed::Yes {
        return;
    }
    core_ctx.with_demux_mut(|DemuxState { socketmap: sockets }| {
        sockets.conns_mut().remove(demux_id, addr).expect("failed to remove from demux");
        let _: Option<_> = bindings_ctx.cancel_timer(timer);
    });
}
/// Sends every segment the connection state machine is willing to emit now,
/// then reschedules `timer` for the next time the state machine wants to be
/// polled, if there is one.
///
/// Each segment is produced by `poll_send` (with a `u32::MAX` limit and the
/// current time) and sent on the connection's IP socket. Returns the
/// `NewlyClosed` value reported when `poll_send` stops producing segments.
fn do_send_inner<SockI, WireI, CC, BC>(
    conn_id: &TcpSocketId<SockI, CC::WeakDeviceId, BC>,
    conn: &mut Connection<SockI, WireI, CC::WeakDeviceId, BC>,
    addr: &ConnAddr<ConnIpAddr<WireI::Addr, NonZeroU16, NonZeroU16>, CC::WeakDeviceId>,
    timer: &mut BC::Timer,
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
) -> NewlyClosed
where
    SockI: DualStackIpExt,
    WireI: DualStackIpExt,
    BC: TcpBindingsContext,
    CC: TransportIpContext<WireI, BC> + CounterContext<TcpCounters<SockI>>,
{
    let outcome = loop {
        let polled = core_ctx.with_counters(|counters| {
            conn.state.poll_send(counters, u32::MAX, bindings_ctx.now(), &conn.socket_options)
        });
        let segment = match polled {
            Ok(segment) => segment,
            Err(newly_closed) => break newly_closed,
        };
        send_tcp_segment(
            core_ctx,
            bindings_ctx,
            Some(conn_id),
            Some(&conn.ip_sock),
            addr.ip.clone(),
            segment,
            &conn.socket_options.ip_options,
        );
    };
    if let Some(next_poll) = conn.state.poll_send_at() {
        let _: Option<_> = bindings_ctx.schedule_timer_instant(next_poll, timer);
    }
    outcome
}
/// Type-level selector for the send buffer in [`AccessBufferSize`].
enum SendBufferSize {}
/// Type-level selector for the receive buffer in [`AccessBufferSize`].
enum ReceiveBufferSize {}
/// Uniform access to one of a socket's buffer sizes; `R` and `S` are the
/// receive and send buffer types.
trait AccessBufferSize<R, S> {
    /// Returns the `(min, max)` range that requested sizes are clamped to.
    fn allowed_range() -> (usize, usize);
    /// Requests `new_size` for the selected buffer, if the state has one.
    fn set_buffer_size(buffers: BuffersRefMut<'_, R, S>, new_size: usize);
    /// Returns the selected buffer's size, or `None` if the state has none.
    fn get_buffer_size(buffers: BuffersRefMut<'_, R, S>) -> Option<usize>;
}
impl<R: Buffer, S: Buffer> AccessBufferSize<R, S> for SendBufferSize {
    /// The allowed range is the send buffer type's capacity range.
    fn allowed_range() -> (usize, usize) {
        S::capacity_range()
    }

    /// Requests `new_size` from a live send buffer, or records it in the
    /// pending sizes. States without a send buffer ignore the request.
    fn set_buffer_size(buffers: BuffersRefMut<'_, R, S>, new_size: usize) {
        match buffers {
            BuffersRefMut::Sizes(BufferSizes { send: pending, receive: _ }) => *pending = new_size,
            BuffersRefMut::Both { send: live, recv: _ } | BuffersRefMut::SendOnly(live) => {
                live.request_capacity(new_size)
            }
            BuffersRefMut::NoBuffers | BuffersRefMut::RecvOnly { .. } => {}
        }
    }

    /// Returns the live send buffer's target capacity or the pending size.
    fn get_buffer_size(buffers: BuffersRefMut<'_, R, S>) -> Option<usize> {
        let size = match buffers {
            BuffersRefMut::Sizes(BufferSizes { send: pending, receive: _ }) => *pending,
            BuffersRefMut::Both { send: live, recv: _ } | BuffersRefMut::SendOnly(live) => {
                live.target_capacity()
            }
            BuffersRefMut::NoBuffers | BuffersRefMut::RecvOnly { .. } => return None,
        };
        Some(size)
    }
}
impl<R: Buffer, S: Buffer> AccessBufferSize<R, S> for ReceiveBufferSize {
    /// The allowed range is the receive buffer type's capacity range.
    fn allowed_range() -> (usize, usize) {
        R::capacity_range()
    }

    /// Requests `new_size` from a live receive buffer, or records it in the
    /// pending sizes. States without a receive buffer ignore the request.
    fn set_buffer_size(buffers: BuffersRefMut<'_, R, S>, new_size: usize) {
        match buffers {
            BuffersRefMut::Sizes(BufferSizes { receive: pending, send: _ }) => *pending = new_size,
            BuffersRefMut::Both { recv: live, send: _ } | BuffersRefMut::RecvOnly(live) => {
                live.request_capacity(new_size)
            }
            BuffersRefMut::NoBuffers | BuffersRefMut::SendOnly(_) => {}
        }
    }

    /// Returns the live receive buffer's target capacity or the pending size.
    fn get_buffer_size(buffers: BuffersRefMut<'_, R, S>) -> Option<usize> {
        let size = match buffers {
            BuffersRefMut::Sizes(BufferSizes { receive: pending, send: _ }) => *pending,
            BuffersRefMut::Both { recv: live, send: _ } | BuffersRefMut::RecvOnly(live) => {
                live.target_capacity()
            }
            BuffersRefMut::NoBuffers | BuffersRefMut::SendOnly(_) => return None,
        };
        Some(size)
    }
}
/// Returns mutable access to the socket's buffers.
///
/// Unbound sockets and bound/listening sockets only carry requested sizes, so
/// they yield `BuffersRefMut::Sizes`. Connected sockets yield whatever buffers
/// their state machine currently holds.
fn get_buffers_mut<I: DualStackIpExt, CC: TcpContext<I, BC>, BC: TcpBindingsContext>(
    state: &mut TcpSocketState<I, CC::WeakDeviceId, BC>,
    converter: MaybeDualStack<CC::DualStackConverter, CC::SingleStackConverter>,
) -> BuffersRefMut<'_, BC::ReceiveBuffer, BC::SendBuffer> {
    match &mut state.socket_state {
        TcpSocketStateInner::Unbound(Unbound { buffer_sizes, .. })
        | TcpSocketStateInner::Bound(BoundSocketState::Listener((
            MaybeListener::Bound(BoundState { buffer_sizes, .. })
            | MaybeListener::Listener(Listener { buffer_sizes, .. }),
            _,
            _,
        ))) => BuffersRefMut::Sizes(buffer_sizes),
        TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => match converter {
            MaybeDualStack::NotDualStack(single_stack) => {
                let (connection, _addr) = single_stack.convert(conn);
                connection.state.buffers_mut()
            }
            MaybeDualStack::DualStack(dual_stack) => match dual_stack.convert(conn) {
                EitherStack::ThisStack((connection, _addr)) => connection.state.buffers_mut(),
                EitherStack::OtherStack((connection, _addr)) => connection.state.buffers_mut(),
            },
        },
    }
}
/// Sets the buffer selected by `Which` on socket `id` to `size`, after
/// clamping `size` into `Which::allowed_range()`.
fn set_buffer_size<
    Which: AccessBufferSize<BC::ReceiveBuffer, BC::SendBuffer>,
    I: DualStackIpExt,
    BC: TcpBindingsContext,
    CC: TcpContext<I, BC>,
>(
    core_ctx: &mut CC,
    id: &TcpSocketId<I, CC::WeakDeviceId, BC>,
    size: usize,
) {
    let (lower, upper) = Which::allowed_range();
    let clamped_size = size.clamp(lower, upper);
    core_ctx.with_socket_mut_and_converter(id, |socket_state, converter| {
        let buffers = get_buffers_mut::<I, CC, BC>(socket_state, converter);
        Which::set_buffer_size(buffers, clamped_size)
    })
}
/// Returns the size of the buffer selected by `Which` on socket `id`, or
/// `None` if the socket's current state has no such buffer.
fn get_buffer_size<
    Which: AccessBufferSize<BC::ReceiveBuffer, BC::SendBuffer>,
    I: DualStackIpExt,
    BC: TcpBindingsContext,
    CC: TcpContext<I, BC>,
>(
    core_ctx: &mut CC,
    id: &TcpSocketId<I, CC::WeakDeviceId, BC>,
) -> Option<usize> {
    core_ctx.with_socket_mut_and_converter(id, |socket_state, converter| {
        let buffers = get_buffers_mut::<I, CC, BC>(socket_state, converter);
        Which::get_buffer_size(buffers)
    })
}
/// Errors from changing a socket's bound device.
/// NOTE(review): variant meanings below are inferred from names; confirm
/// against the set-device implementation.
#[derive(Debug, GenericOverIp)]
#[generic_over_ip()]
pub enum SetDeviceError {
/// Presumably: the new device would conflict with another socket's address.
Conflict,
/// Presumably: the connection is not routable through the new device.
Unroutable,
/// Presumably: the change conflicts with a zoned address already in use.
ZoneChange,
}
/// Errors from accepting a connection.
/// NOTE(review): meanings inferred from names; confirm against `accept`.
#[derive(Debug, GenericOverIp)]
#[generic_over_ip()]
pub enum AcceptError {
/// Presumably: no established connection is currently queued.
WouldBlock,
/// Presumably: the socket is not a listener.
NotSupported,
}
/// Errors from putting a socket into the listening state.
/// NOTE(review): meanings inferred from names; confirm against `listen`.
#[derive(Debug, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ListenError {
/// Presumably: another listener conflicts with this socket's address.
ListenerExists,
/// Presumably: the socket is in a state that cannot listen.
NotSupported,
}
/// Error indicating the operation required a connection that the socket does
/// not have. NOTE(review): inferred from the name; confirm at call sites.
#[derive(Debug, GenericOverIp, Eq, PartialEq)]
#[generic_over_ip()]
pub struct NoConnection;
/// Errors from changing a socket's address-reuse (`SO_REUSEADDR`-style)
/// setting.
#[derive(Debug, GenericOverIp)]
#[generic_over_ip()]
pub enum SetReuseAddrError {
/// Updating a listener's sharing state conflicts with other sockets bound
/// to the same address.
AddrInUse,
/// The setting cannot be changed on a connected socket.
NotSupported,
}
/// Errors from initiating a connection.
#[derive(Debug, Error, GenericOverIp)]
#[generic_over_ip()]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub enum ConnectError {
/// No ephemeral local port could be allocated.
#[error("unable to allocate a port")]
NoPort,
/// No route to the remote host. Also returned when the route's MMS lookup
/// fails or is too small to derive an MSS.
#[error("no route to remote host")]
NoRoute,
/// The remote address's zone could not be resolved against the socket's
/// device.
#[error(transparent)]
Zone(#[from] ZonedAddressError),
/// A connection with the same 4-tuple already exists.
#[error("there is already a connection at the address requested")]
ConnectionExists,
/// `connect` was called on a listening socket.
#[error("called connect on a listener")]
Listener,
/// A connection attempt is already in progress.
#[error("the handshake has already started")]
Pending,
/// The handshake already completed.
#[error("the handshake is completed")]
Completed,
/// The handshake was aborted.
#[error("the handshake is aborted")]
Aborted,
}
/// Errors from binding a socket to a local address.
#[derive(Debug, Error, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum BindError {
/// The socket was already bound.
#[error("the socket was already bound")]
AlreadyBound,
/// The requested local address could not be used.
#[error(transparent)]
LocalAddressError(#[from] LocalAddressError),
}
/// Errors from looking up a socket's original destination.
#[derive(GenericOverIp)]
#[generic_over_ip()]
pub enum OriginalDestinationError {
/// The socket is unbound or a listener.
NotConnected,
/// No original destination is recorded for the connection's tuple.
NotFound,
/// The recorded original destination address is unspecified.
UnspecifiedDestinationAddr,
/// The recorded original destination port is zero.
UnspecifiedDestinationPort,
}
/// Newtype over the IP-version-specific demux socket ID `I::DemuxSocketId`,
/// allowing it to be used through `GenericOverIp` (e.g. in the dual-stack
/// demux ID pair that is `cast` to a specific version).
#[derive(GenericOverIp)]
#[generic_over_ip(I, Ip)]
pub struct DemuxSocketId<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes>(
I::DemuxSocketId<D, BT>,
);
/// Abstracts the demux update performed while connecting, for single-stack
/// and dual-stack sockets alike.
trait DemuxStateAccessor<I: DualStackIpExt, CC: DeviceIdContext<AnyDevice>, BT: TcpBindingsTypes> {
/// Calls `cb` with the socket's demux ID and the demux state for `I`.
/// On success, implementations remove the socket's previous listener
/// entries. If `cb` fails, its error is returned before any removal.
fn update_demux_state_for_connect<
O,
E,
F: FnOnce(
&I::DemuxSocketId<CC::WeakDeviceId, BT>,
&mut DemuxState<I, CC::WeakDeviceId, BT>,
) -> Result<O, E>,
>(
self,
core_ctx: &mut CC,
cb: F,
) -> Result<O, E>;
}
/// `DemuxStateAccessor` for a socket that lives on a single stack.
///
/// Holds the socket's demux ID and, if the socket was bound, the listener
/// address to remove from the demux once the connection is inserted.
struct SingleStackDemuxStateAccessor<
'a,
I: DualStackIpExt,
CC: DeviceIdContext<AnyDevice>,
BT: TcpBindingsTypes,
>(
&'a I::DemuxSocketId<CC::WeakDeviceId, BT>,
Option<ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, CC::WeakDeviceId>>,
);
impl<'a, I, CC, BT> DemuxStateAccessor<I, CC, BT> for SingleStackDemuxStateAccessor<'a, I, CC, BT>
where
I: DualStackIpExt,
BT: TcpBindingsTypes,
CC: DeviceIdContext<AnyDevice> + TcpDemuxContext<I, CC::WeakDeviceId, BT>,
{
fn update_demux_state_for_connect<
O,
E,
F: FnOnce(
&I::DemuxSocketId<CC::WeakDeviceId, BT>,
&mut DemuxState<I, CC::WeakDeviceId, BT>,
) -> Result<O, E>,
>(
self,
core_ctx: &mut CC,
cb: F,
) -> Result<O, E> {
core_ctx.with_demux_mut(|demux| {
let Self(demux_id, listener_addr) = self;
let output = cb(demux_id, demux)?;
if let Some(listener_addr) = listener_addr {
demux
.socketmap
.listeners_mut()
.remove(demux_id, &listener_addr)
.expect("failed to remove a bound socket");
}
Ok(output)
})
}
}
/// `DemuxStateAccessor` for a dual-stack socket.
///
/// Holds the socket ID and, per IP version, the optional listener address to
/// remove from that version's demux once the connection is inserted.
struct DualStackDemuxStateAccessor<
'a,
I: DualStackIpExt,
CC: DeviceIdContext<AnyDevice>,
BT: TcpBindingsTypes,
>(
&'a TcpSocketId<I, CC::WeakDeviceId, BT>,
DualStackTuple<I, Option<ListenerAddr<ListenerIpAddr<I::Addr, NonZeroU16>, CC::WeakDeviceId>>>,
);
impl<'a, SockI, WireI, CC, BT> DemuxStateAccessor<WireI, CC, BT>
for DualStackDemuxStateAccessor<'a, SockI, CC, BT>
where
SockI: DualStackIpExt,
WireI: DualStackIpExt,
BT: TcpBindingsTypes,
CC: DeviceIdContext<AnyDevice>
+ TcpDualStackContext<SockI, CC::WeakDeviceId, BT>
+ TcpDemuxContext<WireI, CC::WeakDeviceId, BT>
+ TcpDemuxContext<WireI::OtherVersion, CC::WeakDeviceId, BT>,
{
/// Runs `cb` against the demux of the wire IP version `WireI`, then
/// removes the socket's listener entries from both IP versions' demuxes.
///
/// The wire-version listener is removed in the same demux access as `cb`.
/// The other-version listener is removed in a separate, later demux
/// access. If `cb` fails, nothing is removed.
///
/// # Panics
///
/// Panics if a listener entry to remove is missing from its demux.
fn update_demux_state_for_connect<
O,
E,
F: FnOnce(
&WireI::DemuxSocketId<CC::WeakDeviceId, BT>,
&mut DemuxState<WireI, CC::WeakDeviceId, BT>,
) -> Result<O, E>,
>(
self,
core_ctx: &mut CC,
cb: F,
) -> Result<O, E> {
let Self(id, local_addr) = self;
// Split the socket's per-version demux IDs and listener addresses into
// (wire version, other version) pairs.
let (DemuxSocketId(wire_id), DemuxSocketId(other_id)) =
core_ctx.dual_stack_demux_id(id.clone()).cast::<WireI>().into_inner();
let (wire_local_addr, other_local_addr) = local_addr.cast::<WireI>().into_inner();
let output = core_ctx.with_demux_mut(|wire_demux: &mut DemuxState<WireI, _, _>| {
let output = cb(&wire_id, wire_demux)?;
if let Some(wire_local_addr) = wire_local_addr {
wire_demux
.socketmap
.listeners_mut()
.remove(&wire_id, &wire_local_addr)
.expect("failed to remove a bound socket");
}
Ok(output)
})?;
if let Some(other_local_addr) = other_local_addr {
core_ctx.with_demux_mut(|other_demux: &mut DemuxState<WireI::OtherVersion, _, _>| {
other_demux
.socketmap
.listeners_mut()
.remove(&other_id, &other_local_addr)
.expect("failed to remove a bound socket");
});
}
Ok(output)
}
}
fn connect_inner<CC, BC, SockI, WireI, Demux>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
sock_id: &TcpSocketId<SockI, CC::WeakDeviceId, BC>,
isn: &IsnGenerator<BC::Instant>,
listener_addr: Option<ListenerAddr<ListenerIpAddr<WireI::Addr, NonZeroU16>, CC::WeakDeviceId>>,
remote_ip: ZonedAddr<SocketIpAddr<WireI::Addr>, CC::DeviceId>,
remote_port: NonZeroU16,
active_open: TakeableRef<'_, BC::ListenerNotifierOrProvidedBuffers>,
buffer_sizes: BufferSizes,
socket_options: SocketOptions,
sharing: SharingState,
demux: Demux,
convert_back_op: impl FnOnce(
Connection<SockI, WireI, CC::WeakDeviceId, BC>,
ConnAddr<ConnIpAddr<WireI::Addr, NonZeroU16, NonZeroU16>, CC::WeakDeviceId>,
) -> SockI::ConnectionAndAddr<CC::WeakDeviceId, BC>,
convert_timer: impl FnOnce(WeakTcpSocketId<SockI, CC::WeakDeviceId, BC>) -> BC::DispatchId,
) -> Result<TcpSocketStateInner<SockI, CC::WeakDeviceId, BC>, ConnectError>
where
SockI: DualStackIpExt,
WireI: DualStackIpExt,
BC: TcpBindingsContext,
CC: TransportIpContext<WireI, BC>
+ DeviceIpSocketHandler<WireI, BC>
+ CounterContext<TcpCounters<SockI>>,
Demux: DemuxStateAccessor<WireI, CC, BC>,
{
let (local_ip, bound_device, local_port) = match listener_addr {
Some(ListenerAddr { ip: ListenerIpAddr { addr, identifier }, device }) => {
(addr.and_then(IpDeviceAddr::new_from_socket_ip_addr), device, Some(identifier))
}
None => (None, None, None),
};
let (remote_ip, device) = remote_ip.resolve_addr_with_device(bound_device)?;
let ip_sock = core_ctx
.new_ip_socket(
bindings_ctx,
device.as_ref().map(|d| d.as_ref()),
local_ip,
remote_ip,
IpProto::Tcp.into(),
&socket_options.ip_options,
)
.map_err(|err| match err {
IpSockCreationError::Route(_) => ConnectError::NoRoute,
})?;
let device_mms = core_ctx.get_mms(bindings_ctx, &ip_sock, &socket_options.ip_options).map_err(
|_err: ip::socket::MmsError| {
ConnectError::NoRoute
},
)?;
let conn_addr =
demux.update_demux_state_for_connect(core_ctx, |demux_id, DemuxState { socketmap }| {
let local_port = local_port.map_or_else(
|| match netstack3_base::simple_randomized_port_alloc(
&mut bindings_ctx.rng(),
&Some(SocketIpAddr::from(*ip_sock.local_ip())),
&TcpPortAlloc(socketmap),
&Some(remote_port),
) {
Some(port) => {
Ok(NonZeroU16::new(port).expect("ephemeral ports must be non-zero"))
}
None => Err(ConnectError::NoPort),
},
Ok,
)?;
let conn_addr = ConnAddr {
ip: ConnIpAddr {
local: (SocketIpAddr::from(*ip_sock.local_ip()), local_port),
remote: (*ip_sock.remote_ip(), remote_port),
},
device: ip_sock.device().cloned(),
};
let _entry = socketmap
.conns_mut()
.try_insert(conn_addr.clone(), sharing, demux_id.clone())
.map_err(|(err, _sharing)| match err {
InsertError::Exists | InsertError::ShadowerExists => {
ConnectError::ConnectionExists
}
InsertError::ShadowAddrExists | InsertError::IndirectConflict => {
panic!("failed to insert connection: {:?}", err)
}
})?;
Ok::<_, ConnectError>(conn_addr)
})?;
let isn = isn.generate::<SocketIpAddr<WireI::Addr>, NonZeroU16>(
bindings_ctx.now(),
conn_addr.ip.local,
conn_addr.ip.remote,
);
let now = bindings_ctx.now();
let mms = Mss::from_mms::<WireI>(device_mms).ok_or(ConnectError::NoRoute)?;
let active_open = active_open.take();
Ok((move || {
let (syn_sent, syn) = Closed::<Initial>::connect(
isn,
now,
active_open,
buffer_sizes,
mms,
Mss::default::<WireI>(),
&socket_options,
);
let state = State::<_, BC::ReceiveBuffer, BC::SendBuffer, _>::SynSent(syn_sent);
let poll_send_at = state.poll_send_at().expect("no retrans timer");
send_tcp_segment(
core_ctx,
bindings_ctx,
Some(&sock_id),
Some(&ip_sock),
conn_addr.ip,
syn.into_empty(),
&socket_options.ip_options,
);
let mut timer = bindings_ctx.new_timer(convert_timer(sock_id.downgrade()));
assert_eq!(bindings_ctx.schedule_timer_instant(poll_send_at, &mut timer), None);
let conn = convert_back_op(
Connection {
accept_queue: None,
state,
ip_sock,
defunct: false,
socket_options,
soft_error: None,
handshake_status: HandshakeStatus::Pending,
},
conn_addr,
);
core_ctx.increment(|counters| &counters.active_connection_openings);
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, sharing, timer })
})())
}
/// Public snapshot of a socket's addressing, by socket state.
#[derive(Clone, Debug, Eq, PartialEq, GenericOverIp)]
#[generic_over_ip(A, IpAddress)]
pub enum SocketInfo<A: IpAddress, D> {
/// The socket is not bound.
Unbound(UnboundInfo<D>),
/// The socket is bound (or listening) but not connected.
Bound(BoundInfo<A, D>),
/// The socket is connected.
Connection(ConnectionInfo<A, D>),
}
/// Information about an unbound socket.
#[derive(Clone, Debug, Eq, PartialEq, GenericOverIp)]
#[generic_over_ip()]
pub struct UnboundInfo<D> {
/// The device the socket is bound to, if any.
pub device: Option<D>,
}
/// Information about a bound or listening socket.
#[derive(Clone, Debug, Eq, PartialEq, GenericOverIp)]
#[generic_over_ip(A, IpAddress)]
pub struct BoundInfo<A: IpAddress, D> {
/// The local address, zoned with `device` when the address requires a
/// zone. `None` when no specific local address is bound.
pub addr: Option<ZonedAddr<SpecifiedAddr<A>, D>>,
/// The local port.
pub port: NonZeroU16,
/// The device the socket is bound to, if any.
pub device: Option<D>,
}
/// Information about a connected socket.
#[derive(Clone, Debug, Eq, PartialEq, GenericOverIp)]
#[generic_over_ip(A, IpAddress)]
pub struct ConnectionInfo<A: IpAddress, D> {
/// The local endpoint, zoned with `device` where required.
pub local_addr: SocketAddr<A, D>,
/// The remote endpoint, zoned with `device` where required.
pub remote_addr: SocketAddr<A, D>,
/// The device the connection is bound to, if any.
pub device: Option<D>,
}
impl<D: Clone, Extra> From<&'_ Unbound<D, Extra>> for UnboundInfo<D> {
fn from(unbound: &Unbound<D, Extra>) -> Self {
let Unbound {
bound_device: device,
buffer_sizes: _,
socket_options: _,
sharing: _,
socket_extra: _,
} = unbound;
Self { device: device.clone() }
}
}
/// Returns `ip` zoned with `device` when a device is given and `ip` accepts a
/// zone (per `AddrAndZone::new`); otherwise returns `ip` unzoned.
fn maybe_zoned<A: IpAddress, D: Clone>(
    ip: SpecifiedAddr<A>,
    device: &Option<D>,
) -> ZonedAddr<SpecifiedAddr<A>, D> {
    let with_zone = device.as_ref().and_then(|zone| AddrAndZone::new(ip, zone));
    match with_zone {
        Some(addr_and_zone) => ZonedAddr::Zoned(addr_and_zone.map_zone(Clone::clone)),
        None => ZonedAddr::Unzoned(ip),
    }
}
impl<A: IpAddress, D: Clone> From<ListenerAddr<ListenerIpAddr<A, NonZeroU16>, D>>
    for BoundInfo<A, D>
{
    /// Builds the public bound-socket view, zoning the local address with
    /// the bound device where needed.
    fn from(addr: ListenerAddr<ListenerIpAddr<A, NonZeroU16>, D>) -> Self {
        let ListenerAddr { ip: ListenerIpAddr { addr: local_ip, identifier: port }, device } = addr;
        let zoned_local = local_ip.map(|socket_ip| maybe_zoned(socket_ip.into(), &device));
        Self { addr: zoned_local, port, device }
    }
}
impl<A: IpAddress, D: Clone> From<ConnAddr<ConnIpAddr<A, NonZeroU16, NonZeroU16>, D>>
    for ConnectionInfo<A, D>
{
    /// Builds the public connection view, zoning both endpoints with the
    /// connection's device where needed.
    fn from(addr: ConnAddr<ConnIpAddr<A, NonZeroU16, NonZeroU16>, D>) -> Self {
        let ConnAddr {
            ip: ConnIpAddr { local: (local_ip, local_port), remote: (remote_ip, remote_port) },
            device,
        } = addr;
        let local_addr = SocketAddr { ip: maybe_zoned(local_ip.into(), &device), port: local_port };
        let remote_addr =
            SocketAddr { ip: maybe_zoned(remote_ip.into(), &device), port: remote_port };
        Self { local_addr, remote_addr, device }
    }
}
impl<CC, BC> HandleableTimer<CC, BC> for TcpTimerId<CC::WeakDeviceId, BC>
where
    BC: TcpBindingsContext,
    CC: TcpContext<Ipv4, BC>
        + TcpContext<Ipv6, BC>
        + CounterContext<TcpCounters<Ipv4>>
        + CounterContext<TcpCounters<Ipv6>>,
{
    /// Routes the timer to the TCP API instance for the socket's IP version.
    fn handle(self, core_ctx: &mut CC, bindings_ctx: &mut BC, _timer_id: BC::UniqueTimerId) {
        let contexts = CtxPair { core_ctx, bindings_ctx };
        match self {
            TcpTimerId::V4(socket_id) => TcpApi::new(contexts).handle_timer(socket_id),
            TcpTimerId::V6(socket_id) => TcpApi::new(contexts).handle_timer(socket_id),
        }
    }
}
/// Sends `segment` between the endpoints in `conn_addr` and records the
/// outcome in the TCP counters.
///
/// When `ip_sock` is `Some`, the segment is serialized and sent through that
/// existing IP socket. When it is `None`, a one-shot IP packet is sent from
/// the local address to the remote address in `conn_addr`.
///
/// On success, `segments_sent` is incremented. So is `resets_sent`,
/// `syns_sent` or `fins_sent`, according to the segment's control flag. On
/// failure, `segment_send_errors` is incremented and the error is logged at
/// debug level, tagged with `socket_id` when one is given. Send failures are
/// not reported to the caller.
fn send_tcp_segment<'a, WireI, SockI, CC, BC, D>(
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
    socket_id: Option<&TcpSocketId<SockI, D, BC>>,
    ip_sock: Option<&IpSock<WireI, D>>,
    conn_addr: ConnIpAddr<WireI::Addr, NonZeroU16, NonZeroU16>,
    segment: Segment<<BC::SendBuffer as SendBuffer>::Payload<'a>>,
    ip_sock_options: &TcpIpSockOptions,
) where
    WireI: IpExt,
    SockI: IpExt + DualStackIpExt,
    CC: CounterContext<TcpCounters<SockI>>
        + IpSocketHandler<WireI, BC, DeviceId = D::Strong, WeakDeviceId = D>,
    BC: TcpBindingsTypes,
    D: WeakDeviceIdentifier,
{
    // Read the control flag now: `segment` is consumed by serialization.
    let control = segment.header.control;
    let result = match ip_sock {
        Some(ip_sock) => {
            let body = tcp_serialize_segment(segment, conn_addr);
            core_ctx
                .send_ip_packet(bindings_ctx, ip_sock, body, ip_sock_options)
                .map_err(IpSockCreateAndSendError::Send)
        }
        None => {
            let ConnIpAddr { local: (local_ip, _), remote: (remote_ip, _) } = conn_addr;
            core_ctx.send_oneshot_ip_packet(
                bindings_ctx,
                None,
                IpDeviceAddr::new_from_socket_ip_addr(local_ip),
                remote_ip,
                IpProto::Tcp.into(),
                ip_sock_options,
                |_addr| tcp_serialize_segment(segment, conn_addr),
            )
        }
    };
    match result {
        Ok(()) => {
            core_ctx.increment(|counters| &counters.segments_sent);
            match control {
                None => {}
                Some(Control::RST) => core_ctx.increment(|counters| &counters.resets_sent),
                Some(Control::SYN) => core_ctx.increment(|counters| &counters.syns_sent),
                Some(Control::FIN) => core_ctx.increment(|counters| &counters.fins_sent),
            }
        }
        Err(err) => {
            core_ctx.increment(|counters| &counters.segment_send_errors);
            match socket_id {
                Some(socket_id) => debug!("{:?}: failed to send segment: {:?}", socket_id, err),
                None => debug!("TCP: failed to send segment: {:?}", err),
            }
        }
    }
}
#[cfg(test)]
mod tests {
use alloc::rc::Rc;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use alloc::{format, vec};
use core::cell::RefCell;
use core::ffi::CStr;
use core::num::NonZeroU16;
use core::time::Duration;
use const_unwrap::const_unwrap_option;
use ip_test_macro::ip_test;
use net_declare::net_ip_v6;
use net_types::ip::{Ip, Ipv4, Ipv6, Ipv6SourceAddr, Mtu};
use net_types::{LinkLocalAddr, Witness};
use netstack3_base::sync::{DynDebugReferences, Mutex};
use netstack3_base::testutil::{
new_rng, run_with_many_seeds, set_logger_for_test, FakeAtomicInstant, FakeCoreCtx,
FakeCryptoRng, FakeDeviceId, FakeInstant, FakeNetwork, FakeNetworkSpec, FakeStrongDeviceId,
FakeTimerCtx, FakeTimerId, FakeWeakDeviceId, InstantAndData, MultipleDevicesId,
PendingFrameData, StepResult, TestIpExt, WithFakeFrameContext, WithFakeTimerContext,
};
use netstack3_base::{
ContextProvider, IcmpIpExt, Icmpv4ErrorCode, Icmpv6ErrorCode, Instant as _, InstantContext,
LinkDevice, Mms, ReferenceNotifiers, StrongDeviceIdentifier, Uninstantiable,
UninstantiableWrapper,
};
use netstack3_filter::{TransportPacketSerializer, Tuple};
use netstack3_ip::device::IpDeviceStateIpExt;
use netstack3_ip::nud::testutil::FakeLinkResolutionNotifier;
use netstack3_ip::nud::LinkResolutionContext;
use netstack3_ip::socket::testutil::{FakeDeviceConfig, FakeDualStackIpSocketCtx};
use netstack3_ip::socket::{IpSockSendError, MmsError, RouteResolutionOptions, SendOptions};
use netstack3_ip::testutil::DualStackSendIpPacketMeta;
use netstack3_ip::{
BaseTransportIpContext, HopLimits, IpTransportContext, ReceiveIpPacketMeta,
};
use packet::{Buf, BufferMut, ParseBuffer as _};
use packet_formats::icmp::{Icmpv4DestUnreachableCode, Icmpv6DestUnreachableCode};
use packet_formats::tcp::{TcpParseArgs, TcpSegment};
use rand::Rng as _;
use test_case::test_case;
use super::*;
use crate::internal::base::{ConnectionError, DEFAULT_FIN_WAIT2_TIMEOUT};
use crate::internal::buffer::testutil::{
ClientBuffers, ProvidedBuffers, RingBuffer, TestSendBuffer, WriteBackClientBuffers,
};
use crate::internal::buffer::BufferLimits;
use crate::internal::state::{TimeWait, MSL};
/// IP-version-specific helpers used by the TCP tests.
///
/// `DualStackIpExt` is listed once; the original listed it twice.
trait TcpTestIpExt: DualStackIpExt + TestIpExt + IpDeviceStateIpExt {
    /// Converter type used when sockets of this IP version are not
    /// dual-stack capable.
    type SingleStackConverter: SingleStackConverter<
        Self,
        FakeWeakDeviceId<FakeDeviceId>,
        TcpBindingsCtx<FakeDeviceId>,
    >;
    /// Converter type used when sockets of this IP version are dual-stack
    /// capable.
    type DualStackConverter: DualStackConverter<
        Self,
        FakeWeakDeviceId<FakeDeviceId>,
        TcpBindingsCtx<FakeDeviceId>,
    >;
    /// Converts `addr` into the source-address type of a received packet.
    fn recv_src_addr(addr: Self::Addr) -> Self::RecvSrcAddr;
    /// Returns the converter appropriate for this IP version.
    fn converter() -> MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter>;
}
/// Bindings types whose timer dispatch ID is the TCP timer ID for device
/// type `D`. This lets the fake core context convert socket timer IDs
/// directly into dispatch IDs.
trait TcpTestBindingsTypes<D: StrongDeviceIdentifier>:
TcpBindingsTypes<DispatchId = TcpTimerId<D::Weak, Self>> + Sized
{
}
// Blanket impl: every bindings type satisfying the bounds qualifies.
impl<D, BT> TcpTestBindingsTypes<D> for BT
where
BT: TcpBindingsTypes<DispatchId = TcpTimerId<D::Weak, Self>> + Sized,
D: StrongDeviceIdentifier,
{
}
/// Fake TCP state for a single IP version.
struct FakeTcpState<I: TcpTestIpExt, D: FakeStrongDeviceId, BT: TcpBindingsTypes> {
/// ISN generator, reference-counted so a handle can be held while the
/// owning context is lent out mutably.
isn_generator: Rc<IsnGenerator<BT::Instant>>,
/// Demux state, accessed through `RefCell` borrows.
demux: Rc<RefCell<DemuxState<I, D::Weak, BT>>>,
/// All sockets of this IP version.
all_sockets: TcpSocketSet<I, D::Weak, BT>,
/// TCP counters for this IP version.
counters: TcpCounters<I>,
}
impl<I, D, BT> Default for FakeTcpState<I, D, BT>
where
    I: TcpTestIpExt,
    D: FakeStrongDeviceId,
    BT: TcpBindingsTypes,
    BT::Instant: Default,
{
    /// Builds state from defaults: a default ISN generator, a default demux
    /// socket map, the default socket set and default counters.
    fn default() -> Self {
        let demux_state = DemuxState { socketmap: Default::default() };
        Self {
            isn_generator: Default::default(),
            demux: Rc::new(RefCell::new(demux_state)),
            all_sockets: Default::default(),
            counters: Default::default(),
        }
    }
}
/// Fake TCP state for both IP versions.
struct FakeDualStackTcpState<D: FakeStrongDeviceId, BT: TcpBindingsTypes> {
/// IPv4 TCP state.
v4: FakeTcpState<Ipv4, D, BT>,
/// IPv6 TCP state.
v6: FakeTcpState<Ipv6, D, BT>,
}
impl<D, BT> Default for FakeDualStackTcpState<D, BT>
where
D: FakeStrongDeviceId,
BT: TcpBindingsTypes,
BT::Instant: Default,
{
fn default() -> Self {
Self { v4: Default::default(), v6: Default::default() }
}
}
/// The fake IP-layer core context wrapped by `TcpCoreCtx`.
type InnerCoreCtx<D> =
FakeCoreCtx<FakeDualStackIpSocketCtx<D>, DualStackSendIpPacketMeta<D>, D>;
/// Fake core context for TCP tests: per-IP-version TCP state plus a fake IP
/// socket context.
struct TcpCoreCtx<D: FakeStrongDeviceId, BT: TcpBindingsTypes> {
/// IPv4 and IPv6 TCP state.
tcp: FakeDualStackTcpState<D, BT>,
/// IP socket context that IP-level operations are delegated to; its
/// `frames` field holds the frames sent through it.
ip_socket_ctx: InnerCoreCtx<D>,
}
// The core context provides itself as its context.
impl<D: FakeStrongDeviceId, BT: TcpBindingsTypes> ContextProvider for TcpCoreCtx<D, BT> {
type Context = Self;
fn context(&mut self) -> &mut Self::Context {
self
}
}
// Devices are identified by the fake strong ID `D` and its weak counterpart.
impl<D, BT> DeviceIdContext<AnyDevice> for TcpCoreCtx<D, BT>
where
D: FakeStrongDeviceId,
BT: TcpBindingsTypes,
{
type DeviceId = D;
type WeakDeviceId = FakeWeakDeviceId<D>;
}
/// A core/bindings context pair representing one node in the TCP tests.
type TcpCtx<D> = CtxPair<TcpCoreCtx<D, TcpBindingsCtx<D>>, TcpBindingsCtx<D>>;
/// Type-level spec selecting the fake-network types for TCP tests. The
/// `Never` field makes it impossible to construct.
struct FakeTcpNetworkSpec<D: FakeStrongDeviceId>(PhantomData<D>, Never);
impl<D: FakeStrongDeviceId> FakeNetworkSpec for FakeTcpNetworkSpec<D> {
type Context = TcpCtx<D>;
type TimerId = TcpTimerId<D::Weak, TcpBindingsCtx<D>>;
type SendMeta = DualStackSendIpPacketMeta<D>;
type RecvMeta = DualStackSendIpPacketMeta<D>;
fn handle_frame(ctx: &mut Self::Context, meta: Self::RecvMeta, buffer: Buf<Vec<u8>>) {
let TcpCtx { core_ctx, bindings_ctx } = ctx;
match meta {
DualStackSendIpPacketMeta::V4(meta) => {
<TcpIpTransportContext as IpTransportContext<Ipv4, _, _>>::receive_ip_packet(
core_ctx,
bindings_ctx,
&meta.device,
Ipv4::recv_src_addr(*meta.src_ip),
meta.dst_ip,
buffer,
ReceiveIpPacketMeta::default(),
)
.expect("failed to deliver bytes");
}
DualStackSendIpPacketMeta::V6(meta) => {
<TcpIpTransportContext as IpTransportContext<Ipv6, _, _>>::receive_ip_packet(
core_ctx,
bindings_ctx,
&meta.device,
Ipv6::recv_src_addr(*meta.src_ip),
meta.dst_ip,
buffer,
ReceiveIpPacketMeta::default(),
)
.expect("failed to deliver bytes");
}
}
}
fn handle_timer(ctx: &mut Self::Context, dispatch: Self::TimerId, _: FakeTimerId) {
match dispatch {
TcpTimerId::V4(id) => ctx.tcp_api().handle_timer(id),
TcpTimerId::V6(id) => ctx.tcp_api().handle_timer(id),
}
}
fn process_queues(_ctx: &mut Self::Context) -> bool {
false
}
fn fake_frames(ctx: &mut Self::Context) -> &mut impl WithFakeFrameContext<Self::SendMeta> {
&mut ctx.core_ctx.ip_socket_ctx.frames
}
}
impl<D: FakeStrongDeviceId> WithFakeTimerContext<TcpTimerId<D::Weak, TcpBindingsCtx<D>>>
    for TcpCtx<D>
{
    /// Lends out the bindings context's fake timers for inspection.
    fn with_fake_timer_ctx<
        O,
        F: FnOnce(&FakeTimerCtx<TcpTimerId<D::Weak, TcpBindingsCtx<D>>>) -> O,
    >(
        &self,
        f: F,
    ) -> O {
        let timers = &self.bindings_ctx.timers;
        f(timers)
    }

    /// Lends out the bindings context's fake timers for modification.
    fn with_fake_timer_ctx_mut<
        O,
        F: FnOnce(&mut FakeTimerCtx<TcpTimerId<D::Weak, TcpBindingsCtx<D>>>) -> O,
    >(
        &mut self,
        f: F,
    ) -> O {
        let timers = &mut self.bindings_ctx.timers;
        f(timers)
    }
}
/// Fake bindings context for TCP tests.
#[derive(Derivative)]
// Derive `Default` without adding bounds on `D`.
#[derivative(Default(bound = ""))]
struct TcpBindingsCtx<D: FakeStrongDeviceId> {
/// Fake RNG handed out by `RngContext::rng`.
rng: FakeCryptoRng,
/// Fake timers (and fake time) whose dispatch IDs are TCP timer IDs.
timers: FakeTimerCtx<TcpTimerId<D::Weak, Self>>,
}
// The bindings context provides itself as its context.
impl<D: FakeStrongDeviceId> ContextProvider for TcpBindingsCtx<D> {
type Context = Self;
fn context(&mut self) -> &mut Self::Context {
self
}
}
// Link resolution uses the NUD test utility's fake notifier.
impl<D: LinkDevice + FakeStrongDeviceId> LinkResolutionContext<D> for TcpBindingsCtx<D> {
type Notifier = FakeLinkResolutionNotifier<D>;
}
// Timer types are taken from the embedded fake timer context so that timer
// operations can be forwarded to it unchanged.
impl<D: FakeStrongDeviceId> TimerBindingsTypes for TcpBindingsCtx<D> {
type Timer = <FakeTimerCtx<TcpTimerId<D::Weak, Self>> as TimerBindingsTypes>::Timer;
type DispatchId =
<FakeTimerCtx<TcpTimerId<D::Weak, Self>> as TimerBindingsTypes>::DispatchId;
type UniqueTimerId =
<FakeTimerCtx<TcpTimerId<D::Weak, Self>> as TimerBindingsTypes>::UniqueTimerId;
}
// Instants are the fake testutil instant types.
impl<D: FakeStrongDeviceId> InstantBindingsTypes for TcpBindingsCtx<D> {
type Instant = FakeInstant;
type AtomicInstant = FakeAtomicInstant;
}
impl<D: FakeStrongDeviceId> InstantContext for TcpBindingsCtx<D> {
    /// Reports the current time of the embedded fake timer context.
    fn now(&self) -> FakeInstant {
        let Self { rng: _, timers } = self;
        timers.now()
    }
}
// Every timer operation is forwarded to the embedded fake timer context.
impl<D: FakeStrongDeviceId> TimerContext for TcpBindingsCtx<D> {
    fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
        let Self { rng: _, timers } = self;
        timers.new_timer(id)
    }

    fn schedule_timer_instant(
        &mut self,
        time: Self::Instant,
        timer: &mut Self::Timer,
    ) -> Option<Self::Instant> {
        let Self { rng: _, timers } = self;
        timers.schedule_timer_instant(time, timer)
    }

    fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
        let Self { rng: _, timers } = self;
        timers.cancel_timer(timer)
    }

    fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
        let Self { rng: _, timers } = self;
        timers.scheduled_instant(timer)
    }

    fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
        let Self { rng: _, timers } = self;
        timers.unique_timer_id(timer)
    }
}
// Tracing is a no-op in these tests.
impl<D: FakeStrongDeviceId> TracingContext for TcpBindingsCtx<D> {
type DurationScope = ();
fn duration(&self, _: &'static CStr) {}
}
// Deferred reference notification is unsupported. Both associated types are
// uninhabited (`Never`), and asking for a notifier panics.
impl<D: FakeStrongDeviceId> ReferenceNotifiers for TcpBindingsCtx<D> {
type ReferenceReceiver<T: 'static> = Never;
type ReferenceNotifier<T: Send + 'static> = Never;
fn new_reference_notifier<T: Send + 'static>(
debug_references: DynDebugReferences,
) -> (Self::ReferenceNotifier<T>, Self::ReferenceReceiver<T>) {
panic!(
"can't create deferred reference notifiers for type {}: \
debug_references={debug_references:?}",
core::any::type_name::<T>()
);
}
}
// `receiver` is uninhabited, so this can never actually run.
impl<D: FakeStrongDeviceId> DeferredResourceRemovalContext for TcpBindingsCtx<D> {
fn defer_removal<T: Send + 'static>(&mut self, receiver: Self::ReferenceReceiver<T>) {
match receiver {}
}
}
impl<D: FakeStrongDeviceId> RngContext for TcpBindingsCtx<D> {
    type Rng<'a> = &'a mut FakeCryptoRng;

    /// Lends out the fake RNG.
    fn rng(&mut self) -> Self::Rng<'_> {
        let Self { rng, timers: _ } = self;
        rng
    }
}
impl<D: FakeStrongDeviceId> TcpBindingsTypes for TcpBindingsCtx<D> {
    type ReceiveBuffer = Arc<Mutex<RingBuffer>>;
    type SendBuffer = TestSendBuffer;
    type ReturnedBuffers = ClientBuffers;
    type ListenerNotifierOrProvidedBuffers = ProvidedBuffers;

    /// Creates buffers for a passively opened connection.
    ///
    /// The socket-side receive/send buffers share storage with the returned
    /// `ClientBuffers`, so a test can observe and feed the connection
    /// through them.
    fn new_passive_open_buffers(
        buffer_sizes: BufferSizes,
    ) -> (Self::ReceiveBuffer, Self::SendBuffer, Self::ReturnedBuffers) {
        let client = ClientBuffers::new(buffer_sizes);
        let receive = Arc::clone(&client.receive);
        let send = TestSendBuffer::new(Arc::clone(&client.send), RingBuffer::default());
        (receive, send, client)
    }

    fn default_buffer_sizes() -> BufferSizes {
        BufferSizes::default()
    }
}
impl<I, D, BC> DeviceIpSocketHandler<I, BC> for TcpCoreCtx<D, BC>
where
    I: TcpTestIpExt,
    D: FakeStrongDeviceId,
    BC: TcpTestBindingsTypes<D>,
{
    /// Reports the MMS derived from a fixed 1500-byte MTU, independent of
    /// the socket and options.
    fn get_mms<O>(
        &mut self,
        _bindings_ctx: &mut BC,
        _ip_sock: &IpSock<I, Self::WeakDeviceId>,
        _options: &O,
    ) -> Result<Mms, MmsError>
    where
        O: RouteResolutionOptions<I>,
    {
        let mtu = Mtu::new(1500);
        let mms = Mms::from_mtu::<I>(mtu, 0).unwrap();
        Ok(mms)
    }
}
// All queries are answered by the inner fake IP socket context.
impl<I, D, BC> BaseTransportIpContext<I, BC> for TcpCoreCtx<D, BC>
where
    I: TcpTestIpExt,
    D: FakeStrongDeviceId,
    BC: TcpTestBindingsTypes<D>,
{
    type DevicesWithAddrIter<'a>
        = <InnerCoreCtx<D> as BaseTransportIpContext<I, BC>>::DevicesWithAddrIter<'a>
    where
        Self: 'a;

    fn with_devices_with_assigned_addr<O, F: FnOnce(Self::DevicesWithAddrIter<'_>) -> O>(
        &mut self,
        addr: SpecifiedAddr<I::Addr>,
        cb: F,
    ) -> O {
        let Self { tcp: _, ip_socket_ctx } = self;
        BaseTransportIpContext::<I, BC>::with_devices_with_assigned_addr(ip_socket_ctx, addr, cb)
    }

    fn get_default_hop_limits(&mut self, device: Option<&Self::DeviceId>) -> HopLimits {
        let Self { tcp: _, ip_socket_ctx } = self;
        BaseTransportIpContext::<I, BC>::get_default_hop_limits(ip_socket_ctx, device)
    }

    fn get_original_destination(&mut self, tuple: &Tuple<I>) -> Option<(I::Addr, u16)> {
        let Self { tcp: _, ip_socket_ctx } = self;
        BaseTransportIpContext::<I, BC>::get_original_destination(ip_socket_ctx, tuple)
    }
}
// IP socket operations are delegated to the inner fake IP socket context.
impl<I: TcpTestIpExt, D: FakeStrongDeviceId, BC: TcpTestBindingsTypes<D>> IpSocketHandler<I, BC>
    for TcpCoreCtx<D, BC>
{
    fn new_ip_socket<O>(
        &mut self,
        bindings_ctx: &mut BC,
        device: Option<EitherDeviceId<&Self::DeviceId, &Self::WeakDeviceId>>,
        local_ip: Option<IpDeviceAddr<I::Addr>>,
        remote_ip: SocketIpAddr<I::Addr>,
        proto: I::Proto,
        options: &O,
    ) -> Result<IpSock<I, Self::WeakDeviceId>, IpSockCreationError>
    where
        O: RouteResolutionOptions<I>,
    {
        let Self { tcp: _, ip_socket_ctx } = self;
        IpSocketHandler::<I, BC>::new_ip_socket(
            ip_socket_ctx,
            bindings_ctx,
            device,
            local_ip,
            remote_ip,
            proto,
            options,
        )
    }

    fn send_ip_packet<S, O>(
        &mut self,
        bindings_ctx: &mut BC,
        socket: &IpSock<I, Self::WeakDeviceId>,
        body: S,
        options: &O,
    ) -> Result<(), IpSockSendError>
    where
        S: TransportPacketSerializer<I>,
        S::Buffer: BufferMut,
        O: SendOptions<I> + RouteResolutionOptions<I>,
    {
        let Self { tcp: _, ip_socket_ctx } = self;
        IpSocketHandler::<I, BC>::send_ip_packet(ip_socket_ctx, bindings_ctx, socket, body, options)
    }

    fn confirm_reachable<O>(
        &mut self,
        bindings_ctx: &mut BC,
        socket: &IpSock<I, Self::WeakDeviceId>,
        options: &O,
    ) where
        O: RouteResolutionOptions<I>,
    {
        let Self { tcp: _, ip_socket_ctx } = self;
        IpSocketHandler::<I, BC>::confirm_reachable(ip_socket_ctx, bindings_ctx, socket, options)
    }
}
impl<D, BC> TcpDemuxContext<Ipv4, D::Weak, BC> for TcpCoreCtx<D, BC>
where
    D: FakeStrongDeviceId,
    BC: TcpTestBindingsTypes<D>,
{
    type IpTransportCtx<'a> = Self;

    /// Lends out the IPv4 demux state for reading.
    fn with_demux<O, F: FnOnce(&DemuxState<Ipv4, D::Weak, BC>) -> O>(&mut self, cb: F) -> O {
        let demux = self.tcp.v4.demux.borrow();
        cb(&*demux)
    }

    /// Lends out the IPv4 demux state for modification.
    fn with_demux_mut<O, F: FnOnce(&mut DemuxState<Ipv4, D::Weak, BC>) -> O>(
        &mut self,
        cb: F,
    ) -> O {
        let mut demux = self.tcp.v4.demux.borrow_mut();
        cb(&mut *demux)
    }
}
impl<D, BC> TcpDemuxContext<Ipv6, D::Weak, BC> for TcpCoreCtx<D, BC>
where
    D: FakeStrongDeviceId,
    BC: TcpTestBindingsTypes<D>,
{
    type IpTransportCtx<'a> = Self;

    /// Lends out the IPv6 demux state for reading.
    fn with_demux<O, F: FnOnce(&DemuxState<Ipv6, D::Weak, BC>) -> O>(&mut self, cb: F) -> O {
        let demux = self.tcp.v6.demux.borrow();
        cb(&*demux)
    }

    /// Lends out the IPv6 demux state for modification.
    fn with_demux_mut<O, F: FnOnce(&mut DemuxState<Ipv6, D::Weak, BC>) -> O>(
        &mut self,
        cb: F,
    ) -> O {
        let mut demux = self.tcp.v6.demux.borrow_mut();
        cb(&mut *demux)
    }
}
// Socket timer IDs convert into the bindings' dispatch ID, which the
// `TcpTestBindingsTypes` bound fixes to `TcpTimerId<D::Weak, BT>`.
impl<I, D, BT> CoreTimerContext<WeakTcpSocketId<I, D::Weak, BT>, BT> for TcpCoreCtx<D, BT>
where
I: DualStackIpExt,
D: FakeStrongDeviceId,
BT: TcpTestBindingsTypes<D>,
{
fn convert_timer(dispatch_id: WeakTcpSocketId<I, D::Weak, BT>) -> BT::DispatchId {
dispatch_id.into()
}
}
impl<D: FakeStrongDeviceId, BC: TcpTestBindingsTypes<D>> TcpContext<Ipv6, BC>
    for TcpCoreCtx<D, BC>
{
    // IPv6 is the dual-stack-capable version here: the dual-stack context is
    // `Self` and the single-stack variants can never be constructed.
    type ThisStackIpTransportAndDemuxCtx<'a> = Self;
    type SingleStackIpTransportAndDemuxCtx<'a> = UninstantiableWrapper<Self>;
    type SingleStackConverter = Uninstantiable;
    type DualStackIpTransportAndDemuxCtx<'a> = Self;
    type DualStackConverter = ();

    /// Lends out the set of all IPv6 sockets.
    fn with_all_sockets_mut<
        O,
        F: FnOnce(&mut TcpSocketSet<Ipv6, Self::WeakDeviceId, BC>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O {
        let Self { tcp, ip_socket_ctx: _ } = self;
        cb(&mut tcp.v6.all_sockets)
    }

    /// Not supported by this fake; always panics.
    fn for_each_socket<
        F: FnMut(
            &TcpSocketId<Ipv6, Self::WeakDeviceId, BC>,
            &TcpSocketState<Ipv6, Self::WeakDeviceId, BC>,
        ),
    >(
        &mut self,
        _cb: F,
    ) {
        unimplemented!()
    }

    /// Calls `cb` with three arguments: `self` as the dual-stack
    /// transport/demux context, the socket's state (via its write guard),
    /// and the IPv6 ISN generator.
    fn with_socket_mut_isn_transport_demux<
        O,
        F: for<'a> FnOnce(
            MaybeDualStack<
                (&'a mut Self::DualStackIpTransportAndDemuxCtx<'a>, Self::DualStackConverter),
                (
                    &'a mut Self::SingleStackIpTransportAndDemuxCtx<'a>,
                    Self::SingleStackConverter,
                ),
            >,
            &mut TcpSocketState<Ipv6, Self::WeakDeviceId, BC>,
            &IsnGenerator<BC::Instant>,
        ) -> O,
    >(
        &mut self,
        id: &TcpSocketId<Ipv6, Self::WeakDeviceId, BC>,
        cb: F,
    ) -> O {
        // Take our own handle to the generator so `self` can be moved into
        // the callback's context argument.
        let isn_generator = Rc::clone(&self.tcp.v6.isn_generator);
        let mut socket_state = id.get_mut();
        cb(MaybeDualStack::DualStack((self, ())), socket_state.deref_mut(), isn_generator.deref())
    }

    /// Calls `cb` with the socket's state (via its write guard) and the
    /// dual-stack converter.
    fn with_socket_and_converter<
        O,
        F: FnOnce(
            &TcpSocketState<Ipv6, Self::WeakDeviceId, BC>,
            MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter>,
        ) -> O,
    >(
        &mut self,
        id: &TcpSocketId<Ipv6, Self::WeakDeviceId, BC>,
        cb: F,
    ) -> O {
        let mut socket_state = id.get_mut();
        cb(socket_state.deref_mut(), MaybeDualStack::DualStack(()))
    }
}
impl<D: FakeStrongDeviceId, BC: TcpTestBindingsTypes<D>> TcpContext<Ipv4, BC>
    for TcpCoreCtx<D, BC>
{
    // IPv4 is never dual-stack: the single-stack context is `Self` and the
    // dual-stack variants can never be constructed.
    type ThisStackIpTransportAndDemuxCtx<'a> = Self;
    type SingleStackIpTransportAndDemuxCtx<'a> = Self;
    type SingleStackConverter = ();
    type DualStackIpTransportAndDemuxCtx<'a> = UninstantiableWrapper<Self>;
    type DualStackConverter = Uninstantiable;

    /// Lends out the set of all IPv4 sockets.
    fn with_all_sockets_mut<
        O,
        F: FnOnce(&mut TcpSocketSet<Ipv4, Self::WeakDeviceId, BC>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O {
        let Self { tcp, ip_socket_ctx: _ } = self;
        cb(&mut tcp.v4.all_sockets)
    }

    /// Not supported by this fake; always panics.
    fn for_each_socket<
        F: FnMut(
            &TcpSocketId<Ipv4, Self::WeakDeviceId, BC>,
            &TcpSocketState<Ipv4, Self::WeakDeviceId, BC>,
        ),
    >(
        &mut self,
        _cb: F,
    ) {
        unimplemented!()
    }

    /// Calls `cb` with three arguments: `self` as the single-stack
    /// transport/demux context, the socket's state (via its write guard),
    /// and the IPv4 ISN generator.
    fn with_socket_mut_isn_transport_demux<
        O,
        F: for<'a> FnOnce(
            MaybeDualStack<
                (&'a mut Self::DualStackIpTransportAndDemuxCtx<'a>, Self::DualStackConverter),
                (
                    &'a mut Self::SingleStackIpTransportAndDemuxCtx<'a>,
                    Self::SingleStackConverter,
                ),
            >,
            &mut TcpSocketState<Ipv4, Self::WeakDeviceId, BC>,
            &IsnGenerator<BC::Instant>,
        ) -> O,
    >(
        &mut self,
        id: &TcpSocketId<Ipv4, Self::WeakDeviceId, BC>,
        cb: F,
    ) -> O {
        // Take our own handle to the generator so `self` can be moved into
        // the callback's context argument.
        let isn_generator = Rc::clone(&self.tcp.v4.isn_generator);
        let mut socket_state = id.get_mut();
        cb(
            MaybeDualStack::NotDualStack((self, ())),
            socket_state.deref_mut(),
            isn_generator.deref(),
        )
    }

    /// Calls `cb` with the socket's state (via its write guard) and the
    /// single-stack converter.
    fn with_socket_and_converter<
        O,
        F: FnOnce(
            &TcpSocketState<Ipv4, Self::WeakDeviceId, BC>,
            MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter>,
        ) -> O,
    >(
        &mut self,
        id: &TcpSocketId<Ipv4, Self::WeakDeviceId, BC>,
        cb: F,
    ) -> O {
        let mut socket_state = id.get_mut();
        cb(socket_state.deref_mut(), MaybeDualStack::NotDualStack(()))
    }
}
impl<D: FakeStrongDeviceId, BT: TcpTestBindingsTypes<D>>
TcpDualStackContext<Ipv6, FakeWeakDeviceId<D>, BT> for TcpCoreCtx<D, BT>
{
type DualStackIpTransportCtx<'a> = Self;
fn other_demux_id_converter(&self) -> impl DualStackDemuxIdConverter<Ipv6> {
Ipv6SocketIdToIpv4DemuxIdConverter
}
fn dual_stack_enabled(&self, ip_options: &Ipv6Options) -> bool {
ip_options.dual_stack_enabled
}
fn set_dual_stack_enabled(&self, ip_options: &mut Ipv6Options, value: bool) {
ip_options.dual_stack_enabled = value;
}
fn with_both_demux_mut<
O,
F: FnOnce(
&mut DemuxState<Ipv6, FakeWeakDeviceId<D>, BT>,
&mut DemuxState<Ipv4, FakeWeakDeviceId<D>, BT>,
) -> O,
>(
&mut self,
cb: F,
) -> O {
cb(&mut self.tcp.v6.demux.borrow_mut(), &mut self.tcp.v4.demux.borrow_mut())
}
}
impl<I: Ip, D: FakeStrongDeviceId, BT: TcpTestBindingsTypes<D>> CounterContext<TcpCounters<I>>
for TcpCoreCtx<D, BT>
{
fn with_counters<O, F: FnOnce(&TcpCounters<I>) -> O>(&self, cb: F) -> O {
let counters = I::map_ip((), |()| &self.tcp.v4.counters, |()| &self.tcp.v6.counters);
cb(counters)
}
}
impl<D, BT> TcpCoreCtx<D, BT>
where
    D: FakeStrongDeviceId,
    BT: TcpBindingsTypes,
    BT::Instant: Default,
{
    /// Creates a core context with default TCP state on top of the given
    /// fake IP socket state.
    fn with_ip_socket_ctx_state(state: FakeDualStackIpSocketCtx<D>) -> Self {
        let ip_socket_ctx = FakeCoreCtx::with_state(state);
        Self { ip_socket_ctx, tcp: Default::default() }
    }
}
impl TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>> {
    /// Creates a core context with one fake device. The device has `addr`
    /// as its only local address and `peer` as its only remote address.
    ///
    /// The prefix length is accepted for call-site symmetry but ignored.
    fn new<I: TcpTestIpExt>(
        addr: SpecifiedAddr<I::Addr>,
        peer: SpecifiedAddr<I::Addr>,
        _prefix: u8,
    ) -> Self {
        let device_config = FakeDeviceConfig {
            device: FakeDeviceId,
            local_ips: vec![addr],
            remote_ips: vec![peer],
        };
        let ip_state = FakeDualStackIpSocketCtx::new(core::iter::once(device_config));
        Self::with_ip_socket_ctx_state(ip_state)
    }
}
impl TcpCoreCtx<MultipleDevicesId, TcpBindingsCtx<MultipleDevicesId>> {
    /// Creates a multi-device core context with no device configurations.
    fn new_multiple_devices() -> Self {
        let no_devices = core::iter::empty::<
            FakeDeviceConfig<MultipleDevicesId, SpecifiedAddr<IpAddr>>,
        >();
        Self::with_ip_socket_ctx_state(FakeDualStackIpSocketCtx::new(no_devices))
    }
}
/// Name of the local node in the fake test network.
const LOCAL: &str = "local";
/// Name of the remote node in the fake test network.
const REMOTE: &str = "remote";
/// A fixed port used by tests.
const PORT_1: NonZeroU16 = const_unwrap_option(NonZeroU16::new(42));
/// A second fixed port, distinct from `PORT_1`.
const PORT_2: NonZeroU16 = const_unwrap_option(NonZeroU16::new(43));
impl TcpTestIpExt for Ipv4 {
    // IPv4 sockets are never dual-stack, so only the single-stack converter
    // can be instantiated.
    type SingleStackConverter = ();
    type DualStackConverter = Uninstantiable;

    fn recv_src_addr(addr: Self::Addr) -> Self::RecvSrcAddr {
        // IPv4 uses the plain address as the received source address.
        addr
    }

    fn converter() -> MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter> {
        MaybeDualStack::NotDualStack(())
    }
}
impl TcpTestIpExt for Ipv6 {
    // IPv6 sockets are dual-stack capable, so only the dual-stack converter
    // can be instantiated.
    type SingleStackConverter = Uninstantiable;
    type DualStackConverter = ();

    fn recv_src_addr(addr: Self::Addr) -> Self::RecvSrcAddr {
        // Panics if `Ipv6SourceAddr::new` rejects `addr`.
        Ipv6SourceAddr::new(addr).unwrap()
    }

    fn converter() -> MaybeDualStack<Self::DualStackConverter, Self::SingleStackConverter> {
        MaybeDualStack::DualStack(())
    }
}
/// Fake network of named TCP test nodes.
///
/// The routing function maps the sending node's name and a packet's
/// metadata to a list of `(destination, metadata, optional duration)`
/// deliveries (the duration presumably being a delivery delay).
type TcpTestNetwork = FakeNetwork<
FakeTcpNetworkSpec<FakeDeviceId>,
&'static str,
fn(
&'static str,
DualStackSendIpPacketMeta<FakeDeviceId>,
) -> Vec<(
&'static str,
DualStackSendIpPacketMeta<FakeDeviceId>,
Option<core::time::Duration>,
)>,
>;
fn new_test_net<I: TcpTestIpExt>() -> TcpTestNetwork {
FakeTcpNetworkSpec::new_network(
[
(
LOCAL,
TcpCtx {
core_ctx: TcpCoreCtx::new::<I>(
I::TEST_ADDRS.local_ip,
I::TEST_ADDRS.remote_ip,
I::TEST_ADDRS.subnet.prefix(),
),
bindings_ctx: TcpBindingsCtx::default(),
},
),
(
REMOTE,
TcpCtx {
core_ctx: TcpCoreCtx::new::<I>(
I::TEST_ADDRS.remote_ip,
I::TEST_ADDRS.local_ip,
I::TEST_ADDRS.subnet.prefix(),
),
bindings_ctx: TcpBindingsCtx::default(),
},
),
],
move |net, meta: DualStackSendIpPacketMeta<_>| {
if net == LOCAL {
alloc::vec![(REMOTE, meta, None)]
} else {
alloc::vec![(LOCAL, meta, None)]
}
},
)
}
impl<I: DualStackIpExt, D: WeakDeviceIdentifier, BT: TcpBindingsTypes> TcpSocketId<I, D, BT> {
    /// Takes a read lock on the socket's state.
    fn get(&self) -> impl Deref<Target = TcpSocketState<I, D, BT>> + '_ {
        self.0.read()
    }

    /// Takes a write lock on the socket's state.
    fn get_mut(&self) -> impl DerefMut<Target = TcpSocketState<I, D, BT>> + '_ {
        self.0.write()
    }
}
/// Resolves `conn` to a connection on this stack, using `converter`.
///
/// # Panics
///
/// Panics if a dual-stack converter reports the connection on the other
/// stack.
fn assert_this_stack_conn<
    'a,
    I: DualStackIpExt,
    BC: TcpBindingsContext,
    CC: TcpContext<I, BC>,
>(
    conn: &'a I::ConnectionAndAddr<CC::WeakDeviceId, BC>,
    converter: &MaybeDualStack<CC::DualStackConverter, CC::SingleStackConverter>,
) -> &'a (
    Connection<I, I, CC::WeakDeviceId, BC>,
    ConnAddr<ConnIpAddr<I::Addr, NonZeroU16, NonZeroU16>, CC::WeakDeviceId>,
) {
    match converter {
        MaybeDualStack::DualStack(dual_stack) => {
            assert_matches!(
                dual_stack.convert(conn),
                EitherStack::ThisStack(this_stack) => this_stack
            )
        }
        MaybeDualStack::NotDualStack(single_stack) => single_stack.convert(conn),
    }
}
/// Convenience accessor for a [`TcpApi`] on any context pair.
trait TcpApiExt: ContextPair + Sized {
/// Returns the TCP API for IP version `I`, borrowing `self`.
fn tcp_api<I: Ip>(&mut self) -> TcpApi<I, &mut Self> {
TcpApi::new(self)
}
}
// Blanket impl: every context pair gets `tcp_api`.
impl<O> TcpApiExt for O where O: ContextPair + Sized {}
/// Socket setup options for `bind_listen_connect_accept_inner`.
struct BindConfig {
/// Port to explicitly bind the client to before connecting; `None` skips
/// the explicit client bind.
client_port: Option<NonZeroU16>,
/// Port the server binds and listens on.
server_port: NonZeroU16,
/// Whether to call `set_reuseaddr(true)` on the client socket.
client_reuse_addr: bool,
/// Whether to exchange `b"Hello"` in both directions after connecting.
send_test_data: bool,
}
fn bind_listen_connect_accept_inner<I: TcpTestIpExt>(
listen_addr: I::Addr,
BindConfig { client_port, server_port, client_reuse_addr, send_test_data }: BindConfig,
seed: u128,
drop_rate: f64,
) -> (
TcpTestNetwork,
TcpSocketId<I, FakeWeakDeviceId<FakeDeviceId>, TcpBindingsCtx<FakeDeviceId>>,
Arc<Mutex<Vec<u8>>>,
TcpSocketId<I, FakeWeakDeviceId<FakeDeviceId>, TcpBindingsCtx<FakeDeviceId>>,
)
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
I,
TcpBindingsCtx<FakeDeviceId>,
SingleStackConverter = I::SingleStackConverter,
DualStackConverter = I::DualStackConverter,
>,
{
let mut net = new_test_net::<I>();
let mut rng = new_rng(seed);
let mut maybe_drop_frame =
|_: &mut TcpCtx<_>, meta: DualStackSendIpPacketMeta<_>, buffer: Buf<Vec<u8>>| {
let x: f64 = rng.gen();
(x > drop_rate).then_some((meta, buffer))
};
let backlog = NonZeroUsize::new(1).unwrap();
let server = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let server = api.create(Default::default());
api.bind(
&server,
SpecifiedAddr::new(listen_addr).map(|a| ZonedAddr::Unzoned(a)),
Some(server_port),
)
.expect("failed to bind the server socket");
api.listen(&server, backlog).expect("can listen");
server
});
let client_ends = WriteBackClientBuffers::default();
let client = net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
let socket = api.create(ProvidedBuffers::Buffers(client_ends.clone()));
if client_reuse_addr {
api.set_reuseaddr(&socket, true).expect("can set");
}
if let Some(port) = client_port {
api.bind(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(port))
.expect("failed to bind the client socket")
}
api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), server_port)
.expect("failed to connect");
socket
});
if drop_rate == 0.0 {
let _: StepResult = net.step();
assert_matches!(
&server.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Listener(Listener {
accept_queue,
..
}), ..))) => {
assert_eq!(accept_queue.ready_len(), 0);
assert_eq!(accept_queue.pending_len(), 1);
}
);
net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
assert_matches!(api.accept(&server), Err(AcceptError::WouldBlock));
});
}
net.run_until_idle_with(&mut maybe_drop_frame);
let (accepted, addr, accepted_ends) = net.with_context(REMOTE, |ctx| {
ctx.tcp_api::<I>().accept(&server).expect("failed to accept")
});
if let Some(port) = client_port {
assert_eq!(
addr,
SocketAddr { ip: ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip), port: port }
);
} else {
assert_eq!(addr.ip, ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip));
}
net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
assert_eq!(
api.connect(
&client,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)),
server_port,
),
Ok(())
);
});
let assert_connected = |conn_id: &TcpSocketId<I, _, _>| {
assert_matches!(
&conn_id.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
assert_matches!(
conn,
Connection {
accept_queue: None,
state: State::Established(_),
ip_sock: _,
defunct: false,
socket_options: _,
soft_error: None,
handshake_status: HandshakeStatus::Completed { reported: true },
}
);
})
};
assert_connected(&client);
assert_connected(&accepted);
let ClientBuffers { send: client_snd_end, receive: client_rcv_end } =
client_ends.0.as_ref().lock().take().unwrap();
let ClientBuffers { send: accepted_snd_end, receive: accepted_rcv_end } = accepted_ends;
if send_test_data {
for snd_end in [client_snd_end.clone(), accepted_snd_end] {
snd_end.lock().extend_from_slice(b"Hello");
}
for (c, id) in [(LOCAL, &client), (REMOTE, &accepted)] {
net.with_context(c, |ctx| ctx.tcp_api::<I>().do_send(id))
}
net.run_until_idle_with(&mut maybe_drop_frame);
for rcv_end in [client_rcv_end, accepted_rcv_end] {
assert_eq!(
rcv_end.lock().read_with(|avail| {
let avail = avail.concat();
assert_eq!(avail, b"Hello");
avail.len()
}),
5
);
}
}
assert_matches!(
&server.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Listener((MaybeListener::Listener(l),..))) => {
assert_eq!(l, &Listener::new(
backlog,
BufferSizes::default(),
SocketOptions::default(),
Default::default()
));
}
);
net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
assert_eq!(api.shutdown(&server, ShutdownType::Receive), Ok(false));
api.close(server);
});
(net, client, client_snd_end, accepted)
}
#[test]
fn test_socket_addr_display() {
    // `SocketAddr` should format as `addr:port` for IPv4, `[addr]:port` for
    // IPv6, and `[addr%zone]:port` for a zoned link-local IPv6 address.
    let port = NonZeroU16::new(1024).expect("failed to create NonZeroU16");

    let v4_addr = SocketAddr {
        ip: maybe_zoned(
            SpecifiedAddr::new(Ipv4Addr::new([192, 168, 0, 1]))
                .expect("failed to create specified addr"),
            &None::<usize>,
        ),
        port,
    };
    assert_eq!(format!("{}", v4_addr), String::from("192.168.0.1:1024"));

    let v6_addr = SocketAddr {
        ip: maybe_zoned(
            SpecifiedAddr::new(Ipv6Addr::new([0x2001, 0xDB8, 0, 0, 0, 0, 0, 1]))
                .expect("failed to create specified addr"),
            &None::<usize>,
        ),
        port,
    };
    assert_eq!(format!("{}", v6_addr), String::from("[2001:db8::1]:1024"));

    let zoned_v6_addr = SocketAddr {
        ip: maybe_zoned(
            SpecifiedAddr::new(Ipv6Addr::new([0xFE80, 0, 0, 0, 0, 0, 0, 1]))
                .expect("failed to create specified addr"),
            &Some(42),
        ),
        port,
    };
    assert_eq!(format!("{}", zoned_v6_addr), String::from("[fe80::1%42]:1024"));
}
#[ip_test(I)]
#[test_case(BindConfig { client_port: None, server_port: PORT_1, client_reuse_addr: false, send_test_data: true }, I::UNSPECIFIED_ADDRESS)]
#[test_case(BindConfig { client_port: Some(PORT_1), server_port: PORT_1, client_reuse_addr: false, send_test_data: true }, I::UNSPECIFIED_ADDRESS)]
#[test_case(BindConfig { client_port: None, server_port: PORT_1, client_reuse_addr: true, send_test_data: true }, I::UNSPECIFIED_ADDRESS)]
#[test_case(BindConfig { client_port: Some(PORT_1), server_port: PORT_1, client_reuse_addr: true, send_test_data: true }, I::UNSPECIFIED_ADDRESS)]
#[test_case(BindConfig { client_port: None, server_port: PORT_1, client_reuse_addr: false, send_test_data: true }, *<I as TestIpExt>::TEST_ADDRS.remote_ip)]
#[test_case(BindConfig { client_port: Some(PORT_1), server_port: PORT_1, client_reuse_addr: false, send_test_data: true }, *<I as TestIpExt>::TEST_ADDRS.remote_ip)]
#[test_case(BindConfig { client_port: None, server_port: PORT_1, client_reuse_addr: true, send_test_data: true }, *<I as TestIpExt>::TEST_ADDRS.remote_ip)]
#[test_case(BindConfig { client_port: Some(PORT_1), server_port: PORT_1, client_reuse_addr: true, send_test_data: true }, *<I as TestIpExt>::TEST_ADDRS.remote_ip)]
/// Runs a loss-free bind/listen/connect/accept exchange and checks each
/// node's TCP counters afterwards.
fn bind_listen_connect_accept<I: TcpTestIpExt>(bind_config: BindConfig, listen_addr: I::Addr)
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    set_logger_for_test();
    let (mut net, _client, _client_snd_end, _accepted) =
        bind_listen_connect_accept_inner::<I>(listen_addr, bind_config, 0, 0.0);

    /// Expected per-node counter values.
    struct ExpectedCounters {
        tx: u64,
        rx: u64,
        passive_open: u64,
        active_open: u64,
    }

    // Every assertion names the node, so a failure identifies which side's
    // counters were wrong.
    let mut assert_counters =
        |context_name: &'static str, ExpectedCounters { tx, rx, passive_open, active_open }| {
            net.with_context(context_name, |ctx| {
                ctx.core_ctx.with_counters(|counters: &TcpCounters<I>| {
                    let c = counters.as_ref();
                    assert_eq!(c.segment_send_errors.get(), 0, "{}", context_name);
                    assert_eq!(c.segments_sent.get(), tx, "{}", context_name);
                    assert_eq!(c.invalid_segments_received.get(), 0, "{}", context_name);
                    assert_eq!(c.valid_segments_received.get(), rx, "{}", context_name);
                    assert_eq!(c.received_segments_dispatched.get(), rx, "{}", context_name);
                    assert_eq!(
                        c.active_connection_openings.get(),
                        active_open,
                        "{}",
                        context_name
                    );
                    assert_eq!(
                        c.passive_connection_openings.get(),
                        passive_open,
                        "{}",
                        context_name
                    );
                    assert_eq!(c.failed_connection_attempts.get(), 0, "{}", context_name);
                    // Each node sends and receives exactly one SYN-flagged
                    // segment during the handshake.
                    assert_eq!(c.syns_sent.get(), 1, "{}", context_name);
                    assert_eq!(c.syns_received.get(), 1, "{}", context_name);
                })
            })
        };
    assert_counters(LOCAL, ExpectedCounters { tx: 4, rx: 3, passive_open: 0, active_open: 1 });
    assert_counters(REMOTE, ExpectedCounters { tx: 3, rx: 4, passive_open: 1, active_open: 0 });
}
#[ip_test(I)]
#[test_case(*<I as TestIpExt>::TEST_ADDRS.local_ip; "same addr")]
#[test_case(I::UNSPECIFIED_ADDRESS; "any addr")]
/// A second bind to an occupied port fails with `AddressInUse` whether it
/// uses the same address or the unspecified one. Binding to a different
/// port then succeeds.
fn bind_conflict<I: TcpTestIpExt>(conflict_addr: I::Addr)
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    set_logger_for_test();
    let core_ctx = TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.subnet.prefix(),
    );
    let mut ctx = TcpCtx::with_core_ctx(core_ctx);
    let mut api = ctx.tcp_api::<I>();

    let first = api.create(Default::default());
    let second = api.create(Default::default());
    api.bind(&first, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(PORT_1))
        .expect("first bind should succeed");

    let conflicting_bind =
        api.bind(&second, SpecifiedAddr::new(conflict_addr).map(ZonedAddr::Unzoned), Some(PORT_1));
    assert_matches!(
        conflicting_bind,
        Err(BindError::LocalAddressError(LocalAddressError::AddressInUse))
    );

    api.bind(&second, SpecifiedAddr::new(conflict_addr).map(ZonedAddr::Unzoned), Some(PORT_2))
        .expect("able to rebind to a free address");
}
// Occupies every port except `available_port` with a listening socket, then
// asks an unbound socket to pick a port by itself. The test cases expect this
// to yield `available_port` when it is `u16::MAX` and to fail with
// `FailedToAllocateLocalPort` when it is 100 (the case names suggest this is
// about which ports are eligible as ephemeral ports). Finally, connecting a
// fresh socket to `available_port` is expected to fail with `NoPort` in both
// cases.
#[ip_test(I)]
#[test_case(const_unwrap_option(NonZeroU16::new(u16::MAX)), Ok(const_unwrap_option(NonZeroU16::new(u16::MAX))); "ephemeral available")]
#[test_case(const_unwrap_option(NonZeroU16::new(100)), Err(LocalAddressError::FailedToAllocateLocalPort);
"no ephemeral available")]
fn bind_picked_port_all_others_taken<I: TcpTestIpExt>(
    available_port: NonZeroU16,
    expected_result: Result<NonZeroU16, LocalAddressError>,
) where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    // Bind and listen on every nonzero port except `available_port`.
    for port in 1..=u16::MAX {
        let port = NonZeroU16::new(port).unwrap();
        if port == available_port {
            continue;
        }
        let socket = api.create(Default::default());
        api.bind(&socket, None, Some(port)).expect("uncontested bind");
        api.listen(&socket, const_unwrap_option(NonZeroUsize::new(1))).expect("can listen");
    }
    // Let the stack choose the port (`None`); on success, report the port
    // it picked so it can be compared with the expected result.
    let socket = api.create(Default::default());
    let result = api.bind(&socket, None, None).map(|()| {
        assert_matches!(
            api.get_info(&socket),
            SocketInfo::Bound(bound) => bound.port
        )
    });
    assert_eq!(result, expected_result.map_err(From::from));
    api.close(socket);
    let socket = api.create(Default::default());
    let result =
        api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), available_port);
    assert_eq!(result, Err(ConnectError::NoPort));
}
#[ip_test(I)]
fn bind_to_non_existent_address<I: TcpTestIpExt>()
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
I::TEST_ADDRS.local_ip,
I::TEST_ADDRS.remote_ip,
I::TEST_ADDRS.subnet.prefix(),
));
let mut api = ctx.tcp_api::<I>();
let unbound = api.create(Default::default());
assert_matches!(
api.bind(&unbound, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), None),
Err(BindError::LocalAddressError(LocalAddressError::AddressMismatch))
);
assert_matches!(unbound.get().deref().socket_state, TcpSocketStateInner::Unbound(_));
}
// Binding to a link-local IPv6 address without a zone must be rejected with
// `RequiredZoneNotProvided`, leaving the socket unbound.
#[test]
fn bind_addr_requires_zone() {
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(
        Ipv6::TEST_ADDRS.local_ip,
        Ipv6::TEST_ADDRS.remote_ip,
        Ipv6::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<Ipv6>();
    let socket = api.create(Default::default());
    let link_local = LinkLocalAddr::new(net_ip_v6!("fe80::1")).unwrap().into_specified();
    let bind_result = api.bind(&socket, Some(ZonedAddr::Unzoned(link_local)), None);
    assert_matches!(
        bind_result,
        Err(BindError::LocalAddressError(LocalAddressError::Zone(
            ZonedAddressError::RequiredZoneNotProvided
        )))
    );
    assert_matches!(socket.get().deref().socket_state, TcpSocketStateInner::Unbound(_));
}
// Connecting an already-bound socket to a link-local remote address without
// a zone fails with `RequiredZoneNotProvided`; the socket stays bound.
#[test]
fn connect_bound_requires_zone() {
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(
        Ipv6::TEST_ADDRS.local_ip,
        Ipv6::TEST_ADDRS.remote_ip,
        Ipv6::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<Ipv6>();
    let socket = api.create(Default::default());
    api.bind(&socket, None, None).expect("bind succeeds");
    let unzoned_remote =
        ZonedAddr::Unzoned(LinkLocalAddr::new(net_ip_v6!("fe80::1")).unwrap().into_specified());
    let connect_result = api.connect(&socket, Some(unzoned_remote), PORT_1);
    assert_matches!(
        connect_result,
        Err(ConnectError::Zone(ZonedAddressError::RequiredZoneNotProvided))
    );
    assert_matches!(socket.get().deref().socket_state, TcpSocketStateInner::Bound(_));
}
// Two listeners may share a port as long as each is bound to a different
// local address of the same device.
#[ip_test(I)]
fn bind_listen_on_same_port_different_addrs<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    set_logger_for_test();
    let addrs = [I::TEST_ADDRS.local_ip, I::TEST_ADDRS.remote_ip];
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::with_ip_socket_ctx_state(
        FakeDualStackIpSocketCtx::new(core::iter::once(FakeDeviceConfig {
            device: FakeDeviceId,
            local_ips: addrs.to_vec(),
            remote_ips: vec![],
        })),
    ));
    let mut api = ctx.tcp_api::<I>();
    // Keep every listener alive until the end of the test.
    let _listeners = addrs.map(|addr| {
        let listener = api.create(Default::default());
        api.bind(&listener, Some(ZonedAddr::Unzoned(addr)), Some(PORT_1)).unwrap();
        api.listen(&listener, NonZeroUsize::MIN).unwrap();
        listener
    });
}
// With SO_REUSEADDR set, two sockets can bind the same port on shadowing
// addresses. Only the first may then listen; the second gets
// `ListenerExists`.
#[ip_test(I)]
#[test_case(None, None; "both any addr")]
#[test_case(None, Some(<I as TestIpExt>::TEST_ADDRS.local_ip); "any then specified")]
#[test_case(Some(<I as TestIpExt>::TEST_ADDRS.local_ip), None; "specified then any")]
#[test_case(
    Some(<I as TestIpExt>::TEST_ADDRS.local_ip),
    Some(<I as TestIpExt>::TEST_ADDRS.local_ip);
    "both specified"
)]
fn cannot_listen_on_same_port_with_shadowed_address<I: TcpTestIpExt>(
    first: Option<SpecifiedAddr<I::Addr>>,
    second: Option<SpecifiedAddr<I::Addr>>,
) where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    set_logger_for_test();
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::with_ip_socket_ctx_state(
        FakeDualStackIpSocketCtx::new(core::iter::once(FakeDeviceConfig {
            device: FakeDeviceId,
            local_ips: vec![I::TEST_ADDRS.local_ip],
            remote_ips: vec![],
        })),
    ));
    let mut api = ctx.tcp_api::<I>();
    // Create, mark reusable, and bind one socket per address, in order.
    let [listener, shadowed] = [first, second].map(|addr| {
        let socket = api.create(Default::default());
        api.set_reuseaddr(&socket, true).unwrap();
        api.bind(&socket, addr.map(ZonedAddr::Unzoned), Some(PORT_1)).unwrap();
        socket
    });
    api.listen(&listener, NonZeroUsize::MIN).unwrap();
    assert_eq!(api.listen(&shadowed, NonZeroUsize::MIN), Err(ListenError::ListenerExists));
}
// An unbound socket on a host whose local address is link-local connects to
// a global server address. The resulting connection must report a zoned
// (link-local) local address and a bound device, and changing its device
// afterwards must be rejected as a zone change.
#[test]
fn connect_unbound_picks_link_local_source_addr() {
    set_logger_for_test();
    let client_ip = SpecifiedAddr::new(net_ip_v6!("fe80::1")).unwrap();
    let server_ip = SpecifiedAddr::new(net_ip_v6!("1:2:3:4::")).unwrap();
    let mut net = FakeTcpNetworkSpec::new_network(
        [
            (LOCAL, TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(client_ip, server_ip, 0))),
            (REMOTE, TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(server_ip, client_ip, 0))),
        ],
        // Every frame is delivered to the other context.
        |net, meta| {
            if net == LOCAL {
                alloc::vec![(REMOTE, meta, None)]
            } else {
                alloc::vec![(LOCAL, meta, None)]
            }
        },
    );
    const PORT: NonZeroU16 = const_unwrap_option(NonZeroU16::new(100));
    let client_connection = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api();
        let socket: TcpSocketId<Ipv6, _, _> = api.create(Default::default());
        api.connect(&socket, Some(ZonedAddr::Unzoned(server_ip)), PORT).expect("can connect");
        socket
    });
    // REMOTE is the server in this test.
    net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<Ipv6>();
        let socket = api.create(Default::default());
        api.bind(&socket, None, Some(PORT)).expect("failed to bind the server socket");
        // `listen` returns `()` on success, so there is no handle to keep.
        api.listen(&socket, NonZeroUsize::MIN).expect("can listen");
    });
    net.run_until_idle();
    net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api();
        // Calling `connect` again on an established connection reports success.
        assert_eq!(
            api.connect(&client_connection, Some(ZonedAddr::Unzoned(server_ip)), PORT),
            Ok(())
        );
        let info = assert_matches!(
            api.get_info(&client_connection),
            SocketInfo::Connection(info) => info
        );
        let (local_ip, remote_ip) = assert_matches!(
            info,
            ConnectionInfo {
                local_addr: SocketAddr { ip: local_ip, port: _ },
                remote_addr: SocketAddr { ip: remote_ip, port: PORT },
                device: Some(FakeWeakDeviceId(FakeDeviceId))
            } => (local_ip, remote_ip)
        );
        assert_eq!(
            local_ip,
            ZonedAddr::Zoned(
                AddrAndZone::new(client_ip, FakeWeakDeviceId(FakeDeviceId)).unwrap()
            )
        );
        assert_eq!(remote_ip, ZonedAddr::Unzoned(server_ip));
        assert_matches!(
            api.set_device(&client_connection, None),
            Err(SetDeviceError::ZoneChange)
        );
    });
}
// A client with a global address connects to a zoned link-local server
// address. The accepted server-side connection must report a zoned local
// address, an unzoned remote address and a bound device, and must reject
// device changes as zone changes.
#[test]
fn accept_connect_picks_link_local_addr() {
    set_logger_for_test();
    let server_ip = SpecifiedAddr::new(net_ip_v6!("fe80::1")).unwrap();
    let client_ip = SpecifiedAddr::new(net_ip_v6!("1:2:3:4::")).unwrap();
    let mut net = FakeTcpNetworkSpec::new_network(
        [
            (LOCAL, TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(server_ip, client_ip, 0))),
            (REMOTE, TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(client_ip, server_ip, 0))),
        ],
        // Every frame is delivered to the other context.
        |net, meta| {
            if net == LOCAL {
                alloc::vec![(REMOTE, meta, None)]
            } else {
                alloc::vec![(LOCAL, meta, None)]
            }
        },
    );
    const PORT: NonZeroU16 = const_unwrap_option(NonZeroU16::new(100));
    // LOCAL is the server in this test.
    let server_listener = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<Ipv6>();
        let socket: TcpSocketId<Ipv6, _, _> = api.create(Default::default());
        api.bind(&socket, None, Some(PORT)).expect("failed to bind the server socket");
        api.listen(&socket, NonZeroUsize::MIN).expect("can listen");
        socket
    });
    let client_connection = net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<Ipv6>();
        let socket = api.create(Default::default());
        api.connect(
            &socket,
            Some(ZonedAddr::Zoned(AddrAndZone::new(server_ip, FakeDeviceId).unwrap())),
            PORT,
        )
        .expect("failed to open a connection");
        socket
    });
    net.run_until_idle();
    net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api();
        let (server_connection, _addr, _buffers) =
            api.accept(&server_listener).expect("connection is waiting");
        let info = assert_matches!(
            api.get_info(&server_connection),
            SocketInfo::Connection(info) => info
        );
        let (local_ip, remote_ip) = assert_matches!(
            info,
            ConnectionInfo {
                local_addr: SocketAddr { ip: local_ip, port: PORT },
                remote_addr: SocketAddr { ip: remote_ip, port: _ },
                device: Some(FakeWeakDeviceId(FakeDeviceId))
            } => (local_ip, remote_ip)
        );
        assert_eq!(
            local_ip,
            ZonedAddr::Zoned(
                AddrAndZone::new(server_ip, FakeWeakDeviceId(FakeDeviceId)).unwrap()
            )
        );
        assert_eq!(remote_ip, ZonedAddr::Unzoned(client_ip));
        assert_matches!(
            api.set_device(&server_connection, None),
            Err(SetDeviceError::ZoneChange)
        );
    });
    // Calling `connect` again on the established client connection reports success.
    net.with_context(REMOTE, |ctx| {
        assert_eq!(
            ctx.tcp_api().connect(
                &client_connection,
                Some(ZonedAddr::Zoned(AddrAndZone::new(server_ip, FakeDeviceId).unwrap())),
                PORT,
            ),
            Ok(())
        );
    });
}
// Connects to `PORT_1` on the remote, where this test never creates a socket.
// The single reply collected after the SYN must carry RST. Once the network
// is idle, the client connection must be `Closed` with `ConnectionRefused`
// and a handshake status of `Aborted`, and a further `connect` must return
// `ConnectError::Aborted`.
#[ip_test(I)]
fn connect_reset<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    set_logger_for_test();
    let mut net = new_test_net::<I>();
    let client = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let conn = api.create(Default::default());
        api.bind(&conn, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(PORT_1))
            .expect("failed to bind the client socket");
        api.connect(&conn, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), PORT_1)
            .expect("failed to connect");
        conn
    });
    // Take one step, then gather the frames produced by it; exactly one
    // pending frame is expected.
    let _: StepResult = net.step();
    net.collect_frames();
    assert_matches!(
        &net.iter_pending_frames().collect::<Vec<_>>()[..],
        [InstantAndData(_instant, PendingFrameData {
            dst_context: _,
            meta,
            frame,
        })] => {
        // Parse the frame as a TCP segment using the IP addresses from its
        // send metadata, and require the RST flag.
        let mut buffer = Buf::new(frame, ..);
        match I::VERSION {
            IpVersion::V4 => {
                let meta = assert_matches!(meta, DualStackSendIpPacketMeta::V4(v4) => v4);
                let parsed = buffer.parse_with::<_, TcpSegment<_>>(
                    TcpParseArgs::new(*meta.src_ip, *meta.dst_ip)
                ).expect("failed to parse");
                assert!(parsed.rst())
            }
            IpVersion::V6 => {
                let meta = assert_matches!(meta, DualStackSendIpPacketMeta::V6(v6) => v6);
                let parsed = buffer.parse_with::<_, TcpSegment<_>>(
                    TcpParseArgs::new(*meta.src_ip, *meta.dst_ip)
                ).expect("failed to parse");
                assert!(parsed.rst())
            }
        }
    });
    net.run_until_idle();
    // The client connection is closed with `ConnectionRefused`.
    assert_matches!(
        &client.get().deref().socket_state,
        TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
            let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
            assert_matches!(
                conn,
                Connection {
                    accept_queue: None,
                    state: State::Closed(Closed {
                        reason: Some(ConnectionError::ConnectionRefused)
                    }),
                    ip_sock: _,
                    defunct: false,
                    socket_options: _,
                    soft_error: None,
                    handshake_status: HandshakeStatus::Aborted,
                }
            );
        });
    // Retrying `connect` on the aborted socket reports `Aborted`.
    net.with_context(LOCAL, |ctx| {
        assert_matches!(
            ctx.tcp_api().connect(
                &client,
                Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)),
                PORT_1
            ),
            Err(ConnectError::Aborted)
        );
    });
}
// Runs the bind/listen/connect/accept scenario with test data over many RNG
// seeds. The final argument to the helper is 0.2; other callers pass 0.0.
// Presumably this is a frame-drop probability that forces retransmissions.
// TODO confirm against `bind_listen_connect_accept_inner`.
#[ip_test(I)]
fn retransmission<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    set_logger_for_test();
    run_with_many_seeds(|seed| {
        let (_net, _client, _client_snd_end, _accepted) = bind_listen_connect_accept_inner::<I>(
            I::UNSPECIFIED_ADDRESS,
            BindConfig {
                client_port: None,
                server_port: PORT_1,
                client_reuse_addr: false,
                send_test_data: true,
            },
            seed,
            0.2,
        );
    });
}
/// Fixed local port used by the device-binding tests that follow.
const LOCAL_PORT: NonZeroU16 = const_unwrap_option(NonZeroU16::new(1845));
// A listener bound to device A on `LOCAL_PORT` blocks a device-less socket
// from binding the same port. After the second socket is bound to device B,
// the same bind succeeds.
#[ip_test(I)]
fn listener_with_bound_device_conflict<I: TcpTestIpExt>()
where
    TcpCoreCtx<MultipleDevicesId, TcpBindingsCtx<MultipleDevicesId>>:
        TcpContext<I, TcpBindingsCtx<MultipleDevicesId>>,
{
    set_logger_for_test();
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new_multiple_devices());
    let mut api = ctx.tcp_api::<I>();
    let sock_a = api.create(Default::default());
    assert_matches!(api.set_device(&sock_a, Some(MultipleDevicesId::A),), Ok(()));
    api.bind(&sock_a, None, Some(LOCAL_PORT)).expect("bind should succeed");
    api.listen(&sock_a, const_unwrap_option(NonZeroUsize::new(10))).expect("can listen");
    let socket = api.create(Default::default());
    // Without a device, the new socket conflicts with the device-A listener.
    assert_matches!(
        api.bind(&socket, None, Some(LOCAL_PORT)),
        Err(BindError::LocalAddressError(LocalAddressError::AddressInUse))
    );
    // Scoping it to a different device removes the conflict.
    assert_matches!(api.set_device(&socket, Some(MultipleDevicesId::B),), Ok(()));
    api.bind(&socket, None, Some(LOCAL_PORT)).expect("no conflict");
}
// Binds to a link-local address zoned to device A. Afterwards, clearing the
// device or switching it to device B must fail with `ZoneChange`.
#[test_case(None)]
#[test_case(Some(MultipleDevicesId::B); "other")]
fn set_bound_device_listener_on_zoned_addr(set_device: Option<MultipleDevicesId>) {
    set_logger_for_test();
    let ll_addr = LinkLocalAddr::new(Ipv6::LINK_LOCAL_UNICAST_SUBNET.network()).unwrap();
    // Every fake device has the same link-local address as both local and remote IP.
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::with_ip_socket_ctx_state(
        FakeDualStackIpSocketCtx::new(MultipleDevicesId::all().into_iter().map(|device| {
            FakeDeviceConfig {
                device,
                local_ips: vec![ll_addr.into_specified()],
                remote_ips: vec![ll_addr.into_specified()],
            }
        })),
    ));
    let mut api = ctx.tcp_api::<Ipv6>();
    let socket = api.create(Default::default());
    api.bind(
        &socket,
        Some(ZonedAddr::Zoned(
            AddrAndZone::new(ll_addr.into_specified(), MultipleDevicesId::A).unwrap(),
        )),
        Some(LOCAL_PORT),
    )
    .expect("bind should succeed");
    assert_matches!(api.set_device(&socket, set_device), Err(SetDeviceError::ZoneChange));
}
// Connects to a link-local remote address zoned to device A. Afterwards,
// clearing the device or switching it to device B must fail with `ZoneChange`.
#[test_case(None)]
#[test_case(Some(MultipleDevicesId::B); "other")]
fn set_bound_device_connected_to_zoned_addr(set_device: Option<MultipleDevicesId>) {
    set_logger_for_test();
    let ll_addr = LinkLocalAddr::new(Ipv6::LINK_LOCAL_UNICAST_SUBNET.network()).unwrap();
    // Every fake device has the same link-local address as both local and remote IP.
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::with_ip_socket_ctx_state(
        FakeDualStackIpSocketCtx::new(MultipleDevicesId::all().into_iter().map(|device| {
            FakeDeviceConfig {
                device,
                local_ips: vec![ll_addr.into_specified()],
                remote_ips: vec![ll_addr.into_specified()],
            }
        })),
    ));
    let mut api = ctx.tcp_api::<Ipv6>();
    let socket = api.create(Default::default());
    api.connect(
        &socket,
        Some(ZonedAddr::Zoned(
            AddrAndZone::new(ll_addr.into_specified(), MultipleDevicesId::A).unwrap(),
        )),
        LOCAL_PORT,
    )
    .expect("connect should succeed");
    assert_matches!(api.set_device(&socket, set_device), Err(SetDeviceError::ZoneChange));
}
// `get_info` on a bound socket, whether or not it is listening, reports
// `SocketInfo::Bound`. The reported address and port are the ones passed to
// `bind` (`None` for the unspecified address), with no device.
#[ip_test(I)]
#[test_case(*<I as TestIpExt>::TEST_ADDRS.local_ip, true; "specified bound")]
#[test_case(I::UNSPECIFIED_ADDRESS, true; "unspecified bound")]
#[test_case(*<I as TestIpExt>::TEST_ADDRS.local_ip, false; "specified listener")]
#[test_case(I::UNSPECIFIED_ADDRESS, false; "unspecified listener")]
fn bound_socket_info<I: TcpTestIpExt>(ip_addr: I::Addr, listen: bool)
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let socket = api.create(Default::default());
    // `SpecifiedAddr::new` yields `None` for the unspecified address.
    let (addr, port) = (SpecifiedAddr::new(ip_addr).map(ZonedAddr::Unzoned), PORT_1);
    api.bind(&socket, addr, Some(port)).expect("bind should succeed");
    // NOTE(review): the test-case names say "bound" when `listen == true`
    // and "listener" when it is false, which looks inverted. Test names are
    // left unchanged here.
    if listen {
        api.listen(&socket, const_unwrap_option(NonZeroUsize::new(25))).expect("can listen");
    }
    let info = api.get_info(&socket);
    assert_eq!(
        info,
        SocketInfo::Bound(BoundInfo {
            addr: addr.map(|a| a.map_zone(FakeWeakDeviceId)),
            port,
            device: None
        })
    );
}
// After binding and connecting, `get_info` reports `SocketInfo::Connection`
// with the bound local address and port, the connected remote address and
// port, and no device.
#[ip_test(I)]
fn connection_info<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let local = SocketAddr { ip: ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip), port: PORT_1 };
    let remote = SocketAddr { ip: ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip), port: PORT_2 };
    let socket = api.create(Default::default());
    api.bind(&socket, Some(local.ip), Some(local.port)).expect("bind should succeed");
    api.connect(&socket, Some(remote.ip), remote.port).expect("connect should succeed");
    // Reported addresses carry weak device IDs as their zone type.
    assert_eq!(
        api.get_info(&socket),
        SocketInfo::Connection(ConnectionInfo {
            local_addr: local.map_zone(FakeWeakDeviceId),
            remote_addr: remote.map_zone(FakeWeakDeviceId),
            device: None,
        }),
    );
}
// A server listens either on the unspecified address or on its zoned
// link-local address, and a link-local client connects to it. The accepted
// connection must report a device, and both its local and remote addresses
// must be zoned to that device.
#[test_case(true; "any")]
#[test_case(false; "link local")]
fn accepted_connection_info_zone(listen_any: bool) {
    set_logger_for_test();
    let client_ip = SpecifiedAddr::new(net_ip_v6!("fe80::1")).unwrap();
    let server_ip = SpecifiedAddr::new(net_ip_v6!("fe80::2")).unwrap();
    let mut net = FakeTcpNetworkSpec::new_network(
        [
            (
                LOCAL,
                TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(
                    server_ip,
                    client_ip,
                    Ipv6::LINK_LOCAL_UNICAST_SUBNET.prefix(),
                )),
            ),
            (
                REMOTE,
                TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(
                    client_ip,
                    server_ip,
                    Ipv6::LINK_LOCAL_UNICAST_SUBNET.prefix(),
                )),
            ),
        ],
        // Every frame is delivered to the other context.
        move |net, meta: DualStackSendIpPacketMeta<_>| {
            if net == LOCAL {
                alloc::vec![(REMOTE, meta, None)]
            } else {
                alloc::vec![(LOCAL, meta, None)]
            }
        },
    );
    // LOCAL is the server in this test.
    let local_server = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<Ipv6>();
        let socket = api.create(Default::default());
        let device = FakeDeviceId;
        let bind_addr = match listen_any {
            true => None,
            false => Some(ZonedAddr::Zoned(AddrAndZone::new(server_ip, device).unwrap())),
        };
        api.bind(&socket, bind_addr, Some(PORT_1)).expect("failed to bind the server socket");
        api.listen(&socket, const_unwrap_option(NonZeroUsize::new(1))).expect("can listen");
        socket
    });
    let _remote_client = net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<Ipv6>();
        let socket = api.create(Default::default());
        let device = FakeDeviceId;
        api.connect(
            &socket,
            Some(ZonedAddr::Zoned(AddrAndZone::new(server_ip, device).unwrap())),
            PORT_1,
        )
        .expect("failed to connect");
        socket
    });
    net.run_until_idle();
    let ConnectionInfo { remote_addr, local_addr, device } = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api();
        let (server_conn, _addr, _buffers) =
            api.accept(&local_server).expect("connection is available");
        assert_matches!(
            api.get_info(&server_conn),
            SocketInfo::Connection(info) => info
        )
    });
    let device = assert_matches!(device, Some(device) => device);
    assert_eq!(
        local_addr,
        SocketAddr {
            ip: ZonedAddr::Zoned(AddrAndZone::new(server_ip, device).unwrap()),
            port: PORT_1
        }
    );
    let SocketAddr { ip: remote_ip, port: _ } = remote_addr;
    assert_eq!(remote_ip, ZonedAddr::Zoned(AddrAndZone::new(client_ip, device).unwrap()));
}
// Binding to a zoned link-local address reports that zone as the bound
// device. Connecting to a zoned link-local peer then reports both zoned
// addresses and the same device.
#[test]
fn bound_connection_info_zoned_addrs() {
    let local_ip = LinkLocalAddr::new(net_ip_v6!("fe80::1")).unwrap().into_specified();
    let remote_ip = LinkLocalAddr::new(net_ip_v6!("fe80::2")).unwrap().into_specified();
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<Ipv6>(
        local_ip,
        remote_ip,
        Ipv6::LINK_LOCAL_UNICAST_SUBNET.prefix(),
    ));
    let local_addr = SocketAddr {
        ip: ZonedAddr::Zoned(AddrAndZone::new(local_ip, FakeDeviceId).unwrap()),
        port: PORT_1,
    };
    let remote_addr = SocketAddr {
        ip: ZonedAddr::Zoned(AddrAndZone::new(remote_ip, FakeDeviceId).unwrap()),
        port: PORT_2,
    };
    let mut api = ctx.tcp_api::<Ipv6>();
    let socket = api.create(Default::default());
    api.bind(&socket, Some(local_addr.ip), Some(local_addr.port)).expect("bind should succeed");
    // The zone of the bound address also becomes the socket's device.
    assert_eq!(
        api.get_info(&socket),
        SocketInfo::Bound(BoundInfo {
            addr: Some(local_addr.ip.map_zone(FakeWeakDeviceId)),
            port: local_addr.port,
            device: Some(FakeWeakDeviceId(FakeDeviceId))
        })
    );
    api.connect(&socket, Some(remote_addr.ip), remote_addr.port)
        .expect("connect should succeed");
    assert_eq!(
        api.get_info(&socket),
        SocketInfo::Connection(ConnectionInfo {
            local_addr: local_addr.map_zone(FakeWeakDeviceId),
            remote_addr: remote_addr.map_zone(FakeWeakDeviceId),
            device: Some(FakeWeakDeviceId(FakeDeviceId))
        })
    );
}
// Closes the local end of an established connection and measures how long
// the local socket takes to be fully removed. The expected time is
// `2 * MSL` if the peer also closes, and `DEFAULT_FIN_WAIT2_TIMEOUT`
// otherwise. When the peer closes, its socket must be removed too.
#[ip_test(I)]
#[test_case(true, 2 * MSL; "peer calls close")]
#[test_case(false, DEFAULT_FIN_WAIT2_TIMEOUT; "peer doesn't call close")]
fn connection_close_peer_calls_close<I: TcpTestIpExt>(
    peer_calls_close: bool,
    expected_time_to_close: Duration,
) where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    set_logger_for_test();
    let (mut net, local, _local_snd_end, remote) = bind_listen_connect_accept_inner::<I>(
        I::UNSPECIFIED_ADDRESS,
        BindConfig {
            client_port: None,
            server_port: PORT_1,
            client_reuse_addr: false,
            send_test_data: false,
        },
        0,
        0.0,
    );
    let weak_local = local.downgrade();
    let closed_at = net.with_context(LOCAL, |ctx| {
        ctx.tcp_api().close(local);
        ctx.bindings_ctx.now()
    });

    // Step the network until the locally-closed connection reaches FIN_WAIT_2.
    loop {
        assert!(!net.step().is_idle());
        let local = weak_local.upgrade().unwrap();
        let state = local.get();
        let reached_fin_wait_2 = assert_matches!(
            &state.deref().socket_state,
            TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
                let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
                let conn_state = assert_matches!(conn, Connection { state, .. } => state);
                matches!(conn_state, State::FinWait2(_))
            }
        );
        if reached_fin_wait_2 {
            break;
        }
    }

    let weak_remote = remote.downgrade();
    if peer_calls_close {
        net.with_context(REMOTE, |ctx| {
            ctx.tcp_api().close(remote);
        });
    }
    net.run_until_idle();

    net.with_context(LOCAL, |TcpCtx { core_ctx: _, bindings_ctx }| {
        let elapsed = bindings_ctx.now().checked_duration_since(closed_at).unwrap();
        assert_eq!(elapsed, expected_time_to_close);
        assert_eq!(weak_local.upgrade(), None);
    });
    if peer_calls_close {
        assert_eq!(weak_remote.upgrade(), None);
    }
}
// The local side shuts down its send direction and the network is stepped
// until the connection reaches FIN_WAIT_2. The socket is then closed while
// the peer never closes, and it must still be fully removed once the
// network is idle.
#[ip_test(I)]
fn connection_shutdown_then_close_peer_doesnt_call_close<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    set_logger_for_test();
    let (mut net, local, _local_snd_end, _remote) = bind_listen_connect_accept_inner::<I>(
        I::UNSPECIFIED_ADDRESS,
        BindConfig {
            client_port: None,
            server_port: PORT_1,
            client_reuse_addr: false,
            send_test_data: false,
        },
        0,
        0.0,
    );
    net.with_context(LOCAL, |ctx| {
        assert_eq!(ctx.tcp_api().shutdown(&local, ShutdownType::Send), Ok(true));
    });
    // Step until the local connection is in FIN_WAIT_2.
    loop {
        assert!(!net.step().is_idle());
        let is_fin_wait_2 = {
            let state = local.get();
            let state = assert_matches!(
            &state.deref().socket_state,
            TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
                let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
                assert_matches!(
                    conn,
                    Connection {
                        state, ..
                    } => state
                )});
            matches!(state, State::FinWait2(_))
        };
        if is_fin_wait_2 {
            break;
        }
    }
    let weak_local = local.downgrade();
    net.with_context(LOCAL, |ctx| {
        ctx.tcp_api().close(local);
    });
    net.run_until_idle();
    assert_eq!(weak_local.upgrade(), None);
}
// Both ends shut down their send direction before any frames are exchanged.
// Each must enter FIN_WAIT_1, and a repeated send shutdown still returns
// `Ok(true)`. After the network is idle both connections must be `Closed`,
// and closing each socket removes it immediately.
#[ip_test(I)]
fn connection_shutdown_then_close<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    set_logger_for_test();
    let (mut net, local, _local_snd_end, remote) = bind_listen_connect_accept_inner::<I>(
        I::UNSPECIFIED_ADDRESS,
        BindConfig {
            client_port: None,
            server_port: PORT_1,
            client_reuse_addr: false,
            send_test_data: false,
        },
        0,
        0.0,
    );
    for (name, id) in [(LOCAL, &local), (REMOTE, &remote)] {
        net.with_context(name, |ctx| {
            let mut api = ctx.tcp_api();
            assert_eq!(
                api.shutdown(id,ShutdownType::Send),
                Ok(true)
            );
            assert_matches!(
                &id.get().deref().socket_state,
                TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
                    let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
                    assert_matches!(
                        conn,
                        Connection {
                            state: State::FinWait1(_),
                            ..
                        }
                    );
            });
            // A second send shutdown is accepted as well.
            assert_eq!(
                api.shutdown(id,ShutdownType::Send),
                Ok(true)
            );
        });
    }
    net.run_until_idle();
    // Both sides are closed; closing the sockets drops them right away.
    for (name, id) in [(LOCAL, local), (REMOTE, remote)] {
        net.with_context(name, |ctx| {
            assert_matches!(
                &id.get().deref().socket_state,
                TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
                    let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
                    assert_matches!(
                        conn,
                        Connection {
                            state: State::Closed(_),
                            ..
                        }
                    );
            });
            let weak_id = id.downgrade();
            ctx.tcp_api().close(id);
            assert_eq!(weak_id.upgrade(), None)
        });
    }
}
// Closing a never-bound socket releases it immediately: no strong
// reference to it survives.
#[ip_test(I)]
fn remove_unbound<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let weak = {
        let socket = api.create(Default::default());
        let weak = socket.downgrade();
        api.close(socket);
        weak
    };
    assert_eq!(weak.upgrade(), None);
}
// Closing a bound, non-listening socket releases it immediately: no strong
// reference to it survives.
#[ip_test(I)]
fn remove_bound<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let weak = {
        let socket = api.create(Default::default());
        api.bind(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), None)
            .expect("bind should succeed");
        let weak = socket.downgrade();
        api.close(socket);
        weak
    };
    assert_eq!(weak.upgrade(), None);
}
#[ip_test(I)]
fn shutdown_listener<I: TcpTestIpExt>()
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
I,
TcpBindingsCtx<FakeDeviceId>,
SingleStackConverter = I::SingleStackConverter,
DualStackConverter = I::DualStackConverter,
>,
{
set_logger_for_test();
let mut net = new_test_net::<I>();
let local_listener = net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
let socket = api.create(Default::default());
api.bind(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(PORT_1))
.expect("bind should succeed");
api.listen(&socket, NonZeroUsize::new(5).unwrap()).expect("can listen");
socket
});
let remote_connection = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let socket = api.create(Default::default());
api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), PORT_1)
.expect("connect should succeed");
socket
});
net.run_until_idle();
net.with_context(REMOTE, |ctx| {
assert_eq!(
ctx.tcp_api().connect(
&remote_connection,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)),
PORT_1
),
Ok(())
);
});
let second_connection = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let socket = api.create(Default::default());
api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), PORT_1)
.expect("connect should succeed");
socket
});
let _: StepResult = net.step();
net.with_context(LOCAL, |TcpCtx { core_ctx: _, bindings_ctx }| {
assert_matches!(bindings_ctx.timers.timers().len(), 1);
});
net.with_context(LOCAL, |ctx| {
assert_eq!(ctx.tcp_api().shutdown(&local_listener, ShutdownType::Receive,), Ok(false));
});
net.with_context(LOCAL, |TcpCtx { core_ctx: _, bindings_ctx }| {
assert_eq!(bindings_ctx.timers.timers().len(), 0);
});
net.run_until_idle();
net.with_context(REMOTE, |ctx| {
for conn in [&remote_connection, &second_connection] {
assert_eq!(
ctx.tcp_api().get_socket_error(conn),
Some(ConnectionError::ConnectionReset),
)
}
assert_matches!(
&remote_connection.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
assert_matches!(
conn,
Connection {
state: State::Closed(Closed {
reason: Some(ConnectionError::ConnectionReset)
}),
..
}
);
}
);
});
net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
let new_unbound = api.create(Default::default());
assert_matches!(
api.bind(
&new_unbound,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip,)),
Some(PORT_1),
),
Err(BindError::LocalAddressError(LocalAddressError::AddressInUse))
);
api.listen(&local_listener, NonZeroUsize::new(5).unwrap()).expect("can listen again");
});
let new_remote_connection = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let socket = api.create(Default::default());
api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), PORT_1)
.expect("connect should succeed");
socket
});
net.run_until_idle();
net.with_context(REMOTE, |ctx| {
assert_matches!(
&new_remote_connection.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
assert_matches!(
conn,
Connection {
state: State::Established(_),
..
}
);
});
assert_eq!(
ctx.tcp_api().connect(
&new_remote_connection,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)),
PORT_1,
),
Ok(())
);
});
}
// Requested send and receive buffer sizes outside the buffer type's
// `capacity_range()` are clamped to its minimum or maximum.
#[ip_test(I)]
fn clamp_buffer_size<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    set_logger_for_test();
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let socket = api.create(Default::default());
    // NOTE(review): `min - 1` assumes the fake buffer's minimum capacity is
    // nonzero; otherwise it would underflow.
    let (min, max) = <
        <TcpBindingsCtx<FakeDeviceId> as TcpBindingsTypes>::SendBuffer as crate::Buffer
    >::capacity_range();
    api.set_send_buffer_size(&socket, min - 1);
    assert_eq!(api.send_buffer_size(&socket), Some(min));
    api.set_send_buffer_size(&socket, max + 1);
    assert_eq!(api.send_buffer_size(&socket), Some(max));
    let (min, max) = <
        <TcpBindingsCtx<FakeDeviceId> as TcpBindingsTypes>::ReceiveBuffer as crate::Buffer
    >::capacity_range();
    api.set_receive_buffer_size(&socket, min - 1);
    assert_eq!(api.receive_buffer_size(&socket), Some(min));
    api.set_receive_buffer_size(&socket, max + 1);
    assert_eq!(api.receive_buffer_size(&socket), Some(max));
}
// Two sockets with SO_REUSEADDR set may both bind with an automatically
// chosen port, and the first of them can then listen.
#[ip_test(I)]
fn set_reuseaddr_unbound<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    // Creates a socket, enables reuseaddr on it and binds it with no
    // address and no port.
    let mut new_reusable_bound = || {
        let socket = api.create(Default::default());
        api.set_reuseaddr(&socket, true).expect("can set");
        api.bind(&socket, None, None).expect("bind succeeds");
        socket
    };
    let first_bound = new_reusable_bound();
    // Kept alive (not `_`) so the second binding persists until the end.
    let _second_bound = new_reusable_bound();
    api.listen(&first_bound, const_unwrap_option(NonZeroUsize::new(10))).expect("can listen");
}
#[ip_test(I)]
#[test_case([true, true], Ok(()); "allowed with set")]
#[test_case([false, true], Err(LocalAddressError::AddressInUse); "first unset")]
#[test_case([true, false], Err(LocalAddressError::AddressInUse); "second unset")]
#[test_case([false, false], Err(LocalAddressError::AddressInUse); "both unset")]
fn reuseaddr_multiple_bound<I: TcpTestIpExt>(
    set_reuseaddr: [bool; 2],
    expected: Result<(), LocalAddressError>,
) where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // Per the cases above, a second bind to PORT_1 only succeeds when both
    // sockets have SO_REUSEADDR set. Otherwise it fails with AddressInUse.
    let [first_reuse, second_reuse] = set_reuseaddr;
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();

    let first = api.create(Default::default());
    api.set_reuseaddr(&first, first_reuse).expect("can set");
    api.bind(&first, None, Some(PORT_1)).expect("bind succeeds");

    let second = api.create(Default::default());
    api.set_reuseaddr(&second, second_reuse).expect("can set");
    assert_eq!(api.bind(&second, None, Some(PORT_1)), expected.map_err(From::from));
}
#[ip_test(I)]
fn toggle_reuseaddr_bound_different_addrs<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // Two sockets share PORT_1 but are bound to different local addresses.
    // SO_REUSEADDR can then be set and cleared on one of them without error.
    let addrs = [1, 2].map(|i| I::get_other_ip_address(i));
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::with_ip_socket_ctx_state(
        FakeDualStackIpSocketCtx::new(core::iter::once(FakeDeviceConfig {
            device: FakeDeviceId,
            local_ips: addrs.iter().cloned().map(SpecifiedAddr::<IpAddr>::from).collect(),
            remote_ips: Default::default(),
        })),
    ));
    let mut api = ctx.tcp_api::<I>();

    // One socket per address, each created and bound in array order.
    let [first, _second] = addrs.map(|addr| {
        let socket = api.create(Default::default());
        api.bind(&socket, Some(ZonedAddr::Unzoned(addr)), Some(PORT_1)).unwrap();
        socket
    });

    api.set_reuseaddr(&first, true).expect("can set");
    api.set_reuseaddr(&first, false).expect("can un-set");
}
#[ip_test(I)]
fn unset_reuseaddr_bound_unspecified_specified<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // One socket is bound to the test local address and another to the
    // unspecified address, both on PORT_1 via SO_REUSEADDR. Clearing the
    // option on either one must then fail with AddrInUse.
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();

    let specified = api.create(Default::default());
    api.set_reuseaddr(&specified, true).expect("can set");
    api.bind(&specified, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(PORT_1))
        .unwrap();

    let unspecified = api.create(Default::default());
    api.set_reuseaddr(&unspecified, true).expect("can set");
    api.bind(&unspecified, None, Some(PORT_1)).unwrap();

    for socket in [&specified, &unspecified] {
        assert_matches!(api.set_reuseaddr(socket, false), Err(SetReuseAddrError::AddrInUse));
    }
}
#[ip_test(I)]
fn reuseaddr_allows_binding_under_connection<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // An SO_REUSEADDR listener on LOCAL accepts a connection from REMOTE and
    // is then closed. A fresh socket binding PORT_1 fails with AddressInUse
    // until it also sets SO_REUSEADDR, after which the bind succeeds.
    set_logger_for_test();
    let mut net = new_test_net::<I>();
    let server = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let server = api.create(Default::default());
        api.set_reuseaddr(&server, true).expect("can set");
        // This is the listening (server) socket, so name it that way in the
        // failure message.
        api.bind(&server, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(PORT_1))
            .expect("failed to bind the server socket");
        api.listen(&server, const_unwrap_option(NonZeroUsize::new(10))).expect("can listen");
        server
    });
    let client = net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let client = api.create(Default::default());
        api.connect(&client, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), PORT_1)
            .expect("connect should succeed");
        client
    });
    net.run_until_idle();
    // After the network goes idle, a repeated `connect` on the client must
    // report `Ok(())`.
    net.with_context(REMOTE, |ctx| {
        assert_eq!(
            ctx.tcp_api().connect(
                &client,
                Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)),
                PORT_1
            ),
            Ok(())
        );
    });
    net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api();
        let (_server_conn, _, _): (_, SocketAddr<_, _>, ClientBuffers) =
            api.accept(&server).expect("pending connection");
        assert_eq!(api.shutdown(&server, ShutdownType::Receive), Ok(false));
        api.close(server);

        // The accepted `_server_conn` is still alive here. Binding PORT_1
        // without SO_REUSEADDR is rejected, and it is allowed once the option
        // is set.
        let unbound = api.create(Default::default());
        assert_eq!(
            api.bind(&unbound, None, Some(PORT_1)),
            Err(BindError::LocalAddressError(LocalAddressError::AddressInUse))
        );
        api.set_reuseaddr(&unbound, true).expect("can set");
        api.bind(&unbound, None, Some(PORT_1)).expect("bind succeeds");
    });
}
#[ip_test(I)]
#[test_case([true, true]; "specified specified")]
#[test_case([false, true]; "any specified")]
#[test_case([true, false]; "specified any")]
#[test_case([false, false]; "any any")]
fn set_reuseaddr_bound_allows_other_bound<I: TcpTestIpExt>(bind_specified: [bool; 2])
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // Each entry of `bind_specified` chooses between the test local address
    // (true) and the unspecified address (false). Sharing PORT_1 only works
    // once both the already-bound socket and the new one set SO_REUSEADDR.
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let [first_addr, second_addr] = bind_specified
        .map(|specified| specified.then(|| ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)));

    let first_bound = api.create(Default::default());
    api.bind(&first_bound, first_addr, Some(PORT_1)).expect("bind succeeds");

    let second = api.create(Default::default());
    // Neither socket has SO_REUSEADDR yet.
    assert_matches!(
        api.bind(&second, second_addr, Some(PORT_1)),
        Err(BindError::LocalAddressError(LocalAddressError::AddressInUse))
    );
    // Setting it on the new socket alone is not enough.
    api.set_reuseaddr(&second, true).expect("can set");
    assert_matches!(
        api.bind(&second, second_addr, Some(PORT_1)),
        Err(BindError::LocalAddressError(LocalAddressError::AddressInUse))
    );
    // The bind succeeds once the original socket opts in as well.
    api.set_reuseaddr(&first_bound, true).expect("only socket");
    api.bind(&second, second_addr, Some(PORT_1)).expect("can bind");
}
#[ip_test(I)]
fn clear_reuseaddr_listener<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // A bound socket and a listener share PORT_1 via SO_REUSEADDR. The
    // listener may not clear the option while the other socket is merely
    // bound. Once that socket has connected, clearing it succeeds.
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();

    let bound = api.create(Default::default());
    api.set_reuseaddr(&bound, true).expect("can set");
    api.bind(&bound, None, Some(PORT_1)).expect("bind succeeds");

    let listener = api.create(Default::default());
    api.set_reuseaddr(&listener, true).expect("can set");
    api.bind(&listener, None, Some(PORT_1)).expect("bind succeeds");
    api.listen(&listener, const_unwrap_option(NonZeroUsize::new(5))).expect("can listen");

    assert_matches!(api.set_reuseaddr(&listener, false), Err(SetReuseAddrError::AddrInUse));

    api.connect(&bound, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), PORT_1)
        .expect("can connect");
    api.set_reuseaddr(&listener, false).expect("can unset");
}
// Hands an ICMP error to the TCP transport layer as if it had been received
// on `FakeDeviceId`.
//
// `original_src_ip` / `original_dst_ip` are the addresses of the packet that
// triggered the error. `original_body` holds the bytes of the original TCP
// segment quoted by the error: callers pass either the whole segment or just
// its first 8 bytes. `err` is the ICMP error code being reported.
fn deliver_icmp_error<
    I: TcpTestIpExt + IcmpIpExt,
    CC: TcpContext<I, BC, DeviceId = FakeDeviceId>
        + TcpContext<I::OtherVersion, BC, DeviceId = FakeDeviceId>
        + CounterContext<TcpCounters<I>>
        + CounterContext<TcpCounters<I::OtherVersion>>,
    // A single bound is sufficient; the original listed the same trait twice.
    BC: TcpBindingsContext,
>(
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
    original_src_ip: SpecifiedAddr<I::Addr>,
    original_dst_ip: SpecifiedAddr<I::Addr>,
    original_body: &[u8],
    err: I::ErrorCode,
) {
    <TcpIpTransportContext as IpTransportContext<I, _, _>>::receive_icmp_error(
        core_ctx,
        bindings_ctx,
        &FakeDeviceId,
        Some(original_src_ip),
        original_dst_ip,
        original_body,
        err,
    );
}
#[test_case(Icmpv4DestUnreachableCode::DestNetworkUnreachable => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestHostUnreachable => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestProtocolUnreachable => ConnectionError::ProtocolUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestPortUnreachable => ConnectionError::PortUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::SourceRouteFailed => ConnectionError::SourceRouteFailed)]
#[test_case(Icmpv4DestUnreachableCode::DestNetworkUnknown => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestHostUnknown => ConnectionError::DestinationHostDown)]
#[test_case(Icmpv4DestUnreachableCode::SourceHostIsolated => ConnectionError::SourceHostIsolated)]
#[test_case(Icmpv4DestUnreachableCode::NetworkAdministrativelyProhibited => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::HostAdministrativelyProhibited => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::NetworkUnreachableForToS => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::HostUnreachableForToS => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::CommAdministrativelyProhibited => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::HostPrecedenceViolation => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::PrecedenceCutoffInEffect => ConnectionError::HostUnreachable)]
fn icmp_destination_unreachable_connect_v4(
    error: Icmpv4DestUnreachableCode,
) -> ConnectionError {
    // Wrap the IPv4 destination-unreachable code in the ICMPv4 error type and
    // return the error a connecting socket records for it.
    let icmp_error = Icmpv4ErrorCode::DestUnreachable(error);
    icmp_destination_unreachable_connect_inner::<Ipv4>(icmp_error)
}
#[test_case(Icmpv6DestUnreachableCode::NoRoute => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::CommAdministrativelyProhibited => ConnectionError::HostUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::BeyondScope => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::AddrUnreachable => ConnectionError::HostUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::PortUnreachable => ConnectionError::PortUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::SrcAddrFailedPolicy => ConnectionError::SourceRouteFailed)]
#[test_case(Icmpv6DestUnreachableCode::RejectRoute => ConnectionError::NetworkUnreachable)]
fn icmp_destination_unreachable_connect_v6(
error: Icmpv6DestUnreachableCode,
) -> ConnectionError {
icmp_destination_unreachable_connect_inner::<Ipv6>(Icmpv6ErrorCode::DestUnreachable(error))
}
// Starts connecting a fresh socket to the test remote address and answers
// the single frame it sends with `icmp_error`. Returns the error the socket
// records in response.
//
// Asserts that exactly one frame was sent and that a repeated `connect`
// afterwards reports `ConnectError::Aborted`.
fn icmp_destination_unreachable_connect_inner<I: TcpTestIpExt + IcmpIpExt>(
icmp_error: I::ErrorCode,
) -> ConnectionError
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<I, TcpBindingsCtx<FakeDeviceId>>
+ TcpContext<I::OtherVersion, TcpBindingsCtx<FakeDeviceId>>,
{
let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
I::TEST_ADDRS.local_ip,
I::TEST_ADDRS.remote_ip,
I::TEST_ADDRS.subnet.prefix(),
));
let mut api = ctx.tcp_api::<I>();
let connection = api.create(Default::default());
api.connect(&connection, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), PORT_1)
.expect("failed to create a connection socket");
// Take the one frame emitted by `connect` (presumably the SYN) so it can be
// quoted back in the ICMP error.
let (core_ctx, bindings_ctx) = api.contexts();
let frames = core_ctx.ip_socket_ctx.take_frames();
let frame = assert_matches!(&frames[..], [(_meta, frame)] => frame);
// Only the first 8 bytes of the TCP segment are quoted. In a TCP header
// those bytes are the source/destination ports and the sequence number.
deliver_icmp_error::<I, _, _>(
core_ctx,
bindings_ctx,
I::TEST_ADDRS.local_ip,
I::TEST_ADDRS.remote_ip,
&frame[0..8],
icmp_error,
);
// The error aborts the in-progress connection.
assert_eq!(
api.connect(&connection, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), PORT_1),
Err(ConnectError::Aborted)
);
// Panics if the socket did not record an error.
api.get_socket_error(&connection).unwrap()
}
#[test_case(Icmpv4DestUnreachableCode::DestNetworkUnreachable => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestHostUnreachable => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestProtocolUnreachable => ConnectionError::ProtocolUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestPortUnreachable => ConnectionError::PortUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::SourceRouteFailed => ConnectionError::SourceRouteFailed)]
#[test_case(Icmpv4DestUnreachableCode::DestNetworkUnknown => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::DestHostUnknown => ConnectionError::DestinationHostDown)]
#[test_case(Icmpv4DestUnreachableCode::SourceHostIsolated => ConnectionError::SourceHostIsolated)]
#[test_case(Icmpv4DestUnreachableCode::NetworkAdministrativelyProhibited => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::HostAdministrativelyProhibited => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::NetworkUnreachableForToS => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::HostUnreachableForToS => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::CommAdministrativelyProhibited => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::HostPrecedenceViolation => ConnectionError::HostUnreachable)]
#[test_case(Icmpv4DestUnreachableCode::PrecedenceCutoffInEffect => ConnectionError::HostUnreachable)]
fn icmp_destination_unreachable_established_v4(
    error: Icmpv4DestUnreachableCode,
) -> ConnectionError {
    // Wrap the IPv4 destination-unreachable code in the ICMPv4 error type and
    // return the error an established socket records for it.
    let icmp_error = Icmpv4ErrorCode::DestUnreachable(error);
    icmp_destination_unreachable_established_inner::<Ipv4>(icmp_error)
}
#[test_case(Icmpv6DestUnreachableCode::NoRoute => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::CommAdministrativelyProhibited => ConnectionError::HostUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::BeyondScope => ConnectionError::NetworkUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::AddrUnreachable => ConnectionError::HostUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::PortUnreachable => ConnectionError::PortUnreachable)]
#[test_case(Icmpv6DestUnreachableCode::SrcAddrFailedPolicy => ConnectionError::SourceRouteFailed)]
#[test_case(Icmpv6DestUnreachableCode::RejectRoute => ConnectionError::NetworkUnreachable)]
fn icmp_destination_unreachable_established_v6(
error: Icmpv6DestUnreachableCode,
) -> ConnectionError {
icmp_destination_unreachable_established_inner::<Ipv6>(Icmpv6ErrorCode::DestUnreachable(
error,
))
}
// Establishes a LOCAL<->REMOTE connection and has LOCAL send one data
// segment. It then delivers `icmp_error` to LOCAL, quoting that whole
// segment. Asserts that the socket records an error but stays `Established`,
// and returns the recorded error.
fn icmp_destination_unreachable_established_inner<I: TcpTestIpExt + IcmpIpExt>(
icmp_error: I::ErrorCode,
) -> ConnectionError
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
I,
TcpBindingsCtx<FakeDeviceId>,
SingleStackConverter = I::SingleStackConverter,
DualStackConverter = I::DualStackConverter,
> + TcpContext<I::OtherVersion, TcpBindingsCtx<FakeDeviceId>>,
{
let (mut net, local, local_snd_end, _remote) = bind_listen_connect_accept_inner::<I>(
I::UNSPECIFIED_ADDRESS,
BindConfig {
client_port: None,
server_port: PORT_1,
client_reuse_addr: false,
send_test_data: false,
},
0,
0.0,
);
// Queue some data on LOCAL and push it out as a single segment.
local_snd_end.lock().extend_from_slice(b"Hello");
net.with_context(LOCAL, |ctx| {
ctx.tcp_api().do_send(&local);
});
// Exactly one frame must be pending. Its bytes become the ICMP quote.
net.collect_frames();
let original_body = assert_matches!(
&net.iter_pending_frames().collect::<Vec<_>>()[..],
[InstantAndData(_instant, PendingFrameData {
dst_context: _,
meta: _,
frame,
})] => {
frame.clone()
});
net.with_context(LOCAL, |ctx| {
let TcpCtx { core_ctx, bindings_ctx } = ctx;
deliver_icmp_error::<I, _, _>(
core_ctx,
bindings_ctx,
I::TEST_ADDRS.local_ip,
I::TEST_ADDRS.remote_ip,
&original_body[..],
icmp_error,
);
// The socket records the error...
let error = assert_matches!(
ctx.tcp_api().get_socket_error(&local),
Some(error) => error
);
// ...but the connection itself remains established.
assert_matches!(
&local.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
assert_matches!(
conn,
Connection {
state: State::Established(_),
..
}
);
}
);
error
})
}
// A port-unreachable ICMP error for a handshake that is still pending in a
// listener's accept queue removes that entry: the queue becomes empty and
// the pending socket no longer exists.
#[ip_test(I)]
fn icmp_destination_unreachable_listener<I: TcpTestIpExt + IcmpIpExt>()
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<I, TcpBindingsCtx<FakeDeviceId>>
+ TcpContext<I::OtherVersion, TcpBindingsCtx<FakeDeviceId>>
+ CounterContext<TcpCounters<I>>,
{
let mut net = new_test_net::<I>();
let backlog = NonZeroUsize::new(1).unwrap();
let server = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let server = api.create(Default::default());
api.bind(&server, None, Some(PORT_1)).expect("failed to bind the server socket");
api.listen(&server, backlog).expect("can listen");
server
});
net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
let conn = api.create(Default::default());
api.connect(&conn, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), PORT_1)
.expect("failed to connect");
});
// Run one step, then capture the single pending frame. It is delivered back
// to REMOTE below as quoted from remote_ip to local_ip, so it is presumably
// the listener's SYN-ACK.
assert!(!net.step().is_idle());
net.collect_frames();
let original_body = assert_matches!(
&net.iter_pending_frames().collect::<Vec<_>>()[..],
[InstantAndData(_instant, PendingFrameData {
dst_context: _,
meta: _,
frame,
})] => {
frame.clone()
});
let icmp_error = I::map_ip(
(),
|()| Icmpv4ErrorCode::DestUnreachable(Icmpv4DestUnreachableCode::DestPortUnreachable),
|()| Icmpv6ErrorCode::DestUnreachable(Icmpv6DestUnreachableCode::PortUnreachable),
);
net.with_context(REMOTE, |TcpCtx { core_ctx, bindings_ctx }| {
// Exactly one pending entry is queued. Keep only a weak reference to it so
// its destruction can be observed afterwards.
let in_queue = {
let state = server.get();
let accept_queue = assert_matches!(
&state.deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Listener(Listener { accept_queue, .. }),
..
))) => accept_queue
);
assert_eq!(accept_queue.len(), 1);
accept_queue.collect_pending().first().unwrap().downgrade()
};
deliver_icmp_error::<I, _, _>(
core_ctx,
bindings_ctx,
I::TEST_ADDRS.remote_ip,
I::TEST_ADDRS.local_ip,
&original_body[..],
icmp_error,
);
// The entry has left the accept queue...
{
let state = server.get();
let queue_len = assert_matches!(
&state.deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Listener((
MaybeListener::Listener(Listener { accept_queue, .. }),
..
))) => accept_queue.len()
);
assert_eq!(queue_len, 0);
}
// ...and the pending socket itself has been destroyed.
assert_eq!(in_queue.upgrade(), None);
});
}
// A new connection may reuse the address/port pair of a connection that is
// still in TIME-WAIT. Checks that the new SYN's ISS falls strictly between
// the TIME-WAIT state's `last_ack` and `last_seq`, and that the reusing
// connection then completes.
#[ip_test(I)]
fn time_wait_reuse<I: TcpTestIpExt>()
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
I,
TcpBindingsCtx<FakeDeviceId>,
SingleStackConverter = I::SingleStackConverter,
DualStackConverter = I::DualStackConverter,
>,
{
set_logger_for_test();
const CLIENT_PORT: NonZeroU16 = const_unwrap_option(NonZeroU16::new(2));
const SERVER_PORT: NonZeroU16 = const_unwrap_option(NonZeroU16::new(1));
let (mut net, local, _local_snd_end, remote) = bind_listen_connect_accept_inner::<I>(
I::UNSPECIFIED_ADDRESS,
BindConfig {
client_port: Some(CLIENT_PORT),
server_port: SERVER_PORT,
client_reuse_addr: true,
send_test_data: false,
},
0,
0.0,
);
// LOCAL opens a SO_REUSEADDR listener with a backlog of 1 on
// local_ip:CLIENT_PORT, the port its client connection already uses.
let listener = net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
let listener = api.create(Default::default());
api.set_reuseaddr(&listener, true).expect("can set");
api.bind(
&listener,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)),
Some(CLIENT_PORT),
)
.expect("failed to bind");
api.listen(&listener, NonZeroUsize::new(1).unwrap()).expect("failed to listen");
listener
});
// REMOTE connects to that listener. The connection is not accepted yet, so
// it presumably occupies the single backlog slot.
let extra_conn = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let extra_conn = api.create(Default::default());
api.connect(&extra_conn, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), CLIENT_PORT)
.expect("failed to connect");
extra_conn
});
net.run_until_idle();
net.with_context(REMOTE, |ctx| {
assert_eq!(
ctx.tcp_api().connect(
&extra_conn,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)),
CLIENT_PORT,
),
Ok(())
);
});
// Close LOCAL first, then REMOTE, so that LOCAL ends up in TIME-WAIT.
let weak_local = local.downgrade();
net.with_context(LOCAL, |ctx| {
ctx.tcp_api().close(local);
});
assert!(!net.step().is_idle());
assert!(!net.step().is_idle());
net.with_context(REMOTE, |ctx| {
ctx.tcp_api().close(remote);
});
assert!(!net.step().is_idle());
assert!(!net.step().is_idle());
// Record the TIME-WAIT sequence bookkeeping and its expiry time.
let (tw_last_seq, tw_last_ack, tw_expiry) = {
assert_matches!(
&weak_local.upgrade().unwrap().get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
assert_matches!(
conn,
Connection {
state: State::TimeWait(TimeWait {
last_seq,
last_ack,
expiry,
..
}), ..
} => (*last_seq, *last_ack, *expiry)
)
}
)
};
// Another REMOTE connection to the listener goes unanswered (presumably
// because its backlog is full). By the time the TIME-WAIT expiry is the next
// scheduled step, it must have timed out.
let conn = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let conn = api.create(Default::default());
api.connect(&conn, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), CLIENT_PORT)
.expect("failed to connect");
conn
});
while net.next_step() != Some(tw_expiry) {
assert!(!net.step().is_idle());
}
assert_matches!(
&conn.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(conn, &I::converter());
assert_matches!(
conn,
Connection {
state: State::Closed(Closed { reason: Some(ConnectionError::TimedOut) }),
..
}
);
});
// Accept the queued connection, freeing the listener's backlog.
net.with_context(LOCAL, |ctx| {
let _accepted =
ctx.tcp_api().accept(&listener).expect("failed to accept a new connection");
});
// REMOTE now connects from remote_ip:SERVER_PORT to local_ip:CLIENT_PORT.
// That is the same address/port pair as the connection LOCAL holds in
// TIME-WAIT.
let conn = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
let socket = api.create(Default::default());
api.bind(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), Some(SERVER_PORT))
.expect("failed to bind");
api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), CLIENT_PORT)
.expect("failed to connect");
socket
});
// The only pending frame must be a SYN. Its sequence number (the ISS) must
// lie strictly after `tw_last_ack` and before `tw_last_seq`.
net.collect_frames();
assert_matches!(
&net.iter_pending_frames().collect::<Vec<_>>()[..],
[InstantAndData(_instant, PendingFrameData {
dst_context: _,
meta,
frame,
})] => {
let mut buffer = Buf::new(frame, ..);
let iss = match I::VERSION {
IpVersion::V4 => {
let meta = assert_matches!(meta, DualStackSendIpPacketMeta::V4(meta) => meta);
let parsed = buffer.parse_with::<_, TcpSegment<_>>(
TcpParseArgs::new(*meta.src_ip, *meta.dst_ip)
).expect("failed to parse");
assert!(parsed.syn());
SeqNum::new(parsed.seq_num())
}
IpVersion::V6 => {
let meta = assert_matches!(meta, DualStackSendIpPacketMeta::V6(meta) => meta);
let parsed = buffer.parse_with::<_, TcpSegment<_>>(
TcpParseArgs::new(*meta.src_ip, *meta.dst_ip)
).expect("failed to parse");
assert!(parsed.syn());
SeqNum::new(parsed.seq_num())
}
};
assert!(iss.after(tw_last_ack) && iss.before(tw_last_seq));
});
// The reusing connection completes successfully.
net.run_until_idle();
net.with_context(REMOTE, |ctx| {
assert_eq!(
ctx.tcp_api().connect(
&conn,
Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)),
CLIENT_PORT
),
Ok(())
);
});
}
#[ip_test(I)]
fn conn_addr_not_available<I: TcpTestIpExt + IcmpIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    // A connection is set up with both client and server on PORT_1. A second
    // SO_REUSEADDR socket may bind the same local address and port, but
    // connecting it to the same remote endpoint fails with
    // `ConnectionExists`.
    set_logger_for_test();
    let config = BindConfig {
        client_port: Some(PORT_1),
        server_port: PORT_1,
        client_reuse_addr: true,
        send_test_data: false,
    };
    let (mut net, _local, _local_snd_end, _remote) =
        bind_listen_connect_accept_inner::<I>(I::UNSPECIFIED_ADDRESS, config, 0, 0.0);
    net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let duplicate = api.create(Default::default());
        api.set_reuseaddr(&duplicate, true).expect("can set");
        api.bind(&duplicate, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.local_ip)), Some(PORT_1))
            .expect("failed to bind");
        let connect_result =
            api.connect(&duplicate, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), PORT_1);
        assert_eq!(connect_result, Err(ConnectError::ConnectionExists));
    });
}
// An IPv6 socket on LOCAL connects to REMOTE's IPv4-mapped IPv6 address over
// an IPv4 test network, where an IPv6 listener accepts the connection.
//
// The matrix varies:
// - `server_bind_ip`: listener bound to unspecified vs. REMOTE's mapped address.
// - `server_bind_port`: explicit PORT_1 vs. an allocated port.
// - `bind_client`: whether the client binds (no address/port given) first.
//
// Verifies data flows in both directions and that both ends report
// IPv4-mapped addresses.
#[test_case::test_matrix(
[None, Some(ZonedAddr::Unzoned((*Ipv4::TEST_ADDRS.remote_ip).to_ipv6_mapped()))],
[None, Some(PORT_1)],
[true, false]
)]
fn dual_stack_connect(
server_bind_ip: Option<ZonedAddr<SpecifiedAddr<Ipv6Addr>, FakeDeviceId>>,
server_bind_port: Option<NonZeroU16>,
bind_client: bool,
) {
set_logger_for_test();
let mut net = new_test_net::<Ipv4>();
let backlog = NonZeroUsize::new(1).unwrap();
// Start the listener and read back the port it actually listens on, which
// matters when `server_bind_port` is None.
let (server, listen_port) = net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<Ipv6>();
let server = api.create(Default::default());
api.bind(&server, server_bind_ip, server_bind_port)
.expect("failed to bind the server socket");
api.listen(&server, backlog).expect("can listen");
let port = assert_matches!(
api.get_info(&server),
SocketInfo::Bound(info) => info.port
);
(server, port)
});
// `client_ends` receives the client's buffers once its connection is set up.
let client_ends = WriteBackClientBuffers::default();
let client = net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<Ipv6>();
let socket = api.create(ProvidedBuffers::Buffers(client_ends.clone()));
if bind_client {
api.bind(&socket, None, None).expect("failed to bind");
}
api.connect(
&socket,
Some(ZonedAddr::Unzoned((*Ipv4::TEST_ADDRS.remote_ip).to_ipv6_mapped())),
listen_port,
)
.expect("failed to connect");
socket
});
net.run_until_idle();
// The server sees the client as LOCAL's IPv4-mapped address.
let (accepted, addr, accepted_ends) = net
.with_context(REMOTE, |ctx| ctx.tcp_api().accept(&server).expect("failed to accept"));
assert_eq!(addr.ip, ZonedAddr::Unzoned((*Ipv4::TEST_ADDRS.local_ip).to_ipv6_mapped()));
let ClientBuffers { send: client_snd_end, receive: client_rcv_end } =
client_ends.0.as_ref().lock().take().unwrap();
let ClientBuffers { send: accepted_snd_end, receive: accepted_rcv_end } = accepted_ends;
// Each side sends "Hello" and must receive exactly "Hello" from its peer.
for snd_end in [client_snd_end, accepted_snd_end] {
snd_end.lock().extend_from_slice(b"Hello");
}
net.with_context(LOCAL, |ctx| ctx.tcp_api().do_send(&client));
net.with_context(REMOTE, |ctx| ctx.tcp_api().do_send(&accepted));
net.run_until_idle();
for rcv_end in [client_rcv_end, accepted_rcv_end] {
assert_eq!(
rcv_end.lock().read_with(|avail| {
let avail = avail.concat();
assert_eq!(avail, b"Hello");
avail.len()
}),
5
);
}
// The client's connection info reports IPv4-mapped local and remote
// addresses and the listener's port.
let info = assert_matches!(
net.with_context(LOCAL, |ctx| ctx.tcp_api().get_info(&client)),
SocketInfo::Connection(info) => info
);
let (local_ip, remote_ip, port) = assert_matches!(
info,
ConnectionInfo {
local_addr: SocketAddr { ip: local_ip, port: _ },
remote_addr: SocketAddr { ip: remote_ip, port },
device: _
} => (local_ip.addr(), remote_ip.addr(), port)
);
assert_eq!(remote_ip, Ipv4::TEST_ADDRS.remote_ip.to_ipv6_mapped());
assert_matches!(local_ip.to_ipv4_mapped(), Some(_));
assert_eq!(port, listen_port);
}
#[test]
fn ipv6_dual_stack_enabled() {
    // IPv6 TCP sockets start with dual-stack enabled. After it is disabled,
    // IPv4-mapped addresses can be neither bound nor connected to.
    set_logger_for_test();
    let mut net = new_test_net::<Ipv4>();
    net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<Ipv6>();
        let socket = api.create(Default::default());
        assert_eq!(api.dual_stack_enabled(&socket), Ok(true));
        api.set_dual_stack_enabled(&socket, false).expect("failed to disable dual stack");
        assert_eq!(api.dual_stack_enabled(&socket), Ok(false));

        let mapped_local = ZonedAddr::Unzoned((*Ipv4::TEST_ADDRS.local_ip).to_ipv6_mapped());
        let mapped_remote = ZonedAddr::Unzoned((*Ipv4::TEST_ADDRS.remote_ip).to_ipv6_mapped());
        assert_eq!(
            api.bind(&socket, Some(mapped_local), Some(PORT_1)),
            Err(BindError::LocalAddressError(LocalAddressError::CannotBindToAddress))
        );
        assert_eq!(api.connect(&socket, Some(mapped_remote), PORT_1), Err(ConnectError::NoRoute));
    });
}
#[test]
fn ipv4_dual_stack_enabled() {
    // IPv4 TCP sockets have no dual-stack mode. Both querying it and trying
    // to enable it report that the socket is not dual-stack capable.
    set_logger_for_test();
    let mut net = new_test_net::<Ipv4>();
    net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<Ipv4>();
        let socket = api.create(Default::default());
        let queried = api.dual_stack_enabled(&socket);
        assert_eq!(queried, Err(NotDualStackCapableError));
        let enable_result = api.set_dual_stack_enabled(&socket, true);
        assert_eq!(enable_result, Err(SetDualStackEnabledError::NotCapable));
    });
}
#[ip_test(I)]
fn closed_not_in_demux<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    // While the connection is up, each side's demux socket map holds one
    // entry. After both ends fully shut down and the network goes idle, both
    // maps are empty.
    let (mut net, local, _local_snd_end, remote) = bind_listen_connect_accept_inner::<I>(
        I::UNSPECIFIED_ADDRESS,
        BindConfig {
            client_port: None,
            server_port: PORT_1,
            client_reuse_addr: false,
            send_test_data: false,
        },
        0,
        0.0,
    );

    [LOCAL, REMOTE].into_iter().for_each(|name| {
        net.with_context(name, |CtxPair { core_ctx, .. }| {
            TcpDemuxContext::<I, _, _>::with_demux(core_ctx, |DemuxState { socketmap }| {
                assert_eq!(socketmap.len(), 1);
            })
        })
    });

    for (name, socket) in [(LOCAL, &local), (REMOTE, &remote)] {
        net.with_context(name, |ctx| {
            let shut_down = ctx.tcp_api().shutdown(socket, ShutdownType::SendAndReceive);
            assert_eq!(shut_down, Ok(true));
        });
    }
    net.run_until_idle();

    [LOCAL, REMOTE].into_iter().for_each(|name| {
        net.with_context(name, |CtxPair { core_ctx, .. }| {
            TcpDemuxContext::<I, _, _>::with_demux(core_ctx, |DemuxState { socketmap }| {
                assert_eq!(socketmap.len(), 0);
            })
        })
    });
}
#[ip_test(I)]
fn tcp_accept_queue_clean_up_closed<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // A client closes while its connection is still pending in the
    // listener's accept queue. The entry disappears from the queue, and the
    // listener ends up as the only socket on the server side.
    let mut net = new_test_net::<I>();
    let backlog = NonZeroUsize::new(1).unwrap();
    let server_port = NonZeroU16::new(1024).unwrap();
    let server = net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let server = api.create(Default::default());
        api.bind(&server, None, Some(server_port)).expect("failed to bind the server socket");
        api.listen(&server, backlog).expect("can listen");
        server
    });
    let client = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let socket = api.create(ProvidedBuffers::Buffers(WriteBackClientBuffers::default()));
        api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), server_port)
            .expect("failed to connect");
        socket
    });

    // Reports the listener's accept queue lengths as (ready, pending).
    let accept_queue_lens = || {
        assert_matches!(
            &server.get().deref().socket_state,
            TcpSocketStateInner::Bound(BoundSocketState::Listener((
                MaybeListener::Listener(Listener { accept_queue, .. }),
                ..
            ))) => (accept_queue.ready_len(), accept_queue.pending_len())
        )
    };

    assert!(!net.step().is_idle());
    assert_eq!(accept_queue_lens(), (0, 1));

    net.with_context(LOCAL, |ctx| {
        ctx.tcp_api::<I>().close(client);
    });
    net.run_until_idle();
    assert_eq!(accept_queue_lens(), (0, 0));

    net.with_context(REMOTE, |ctx| {
        ctx.core_ctx.with_all_sockets_mut(|all_sockets| {
            assert_eq!(all_sockets.keys().collect::<Vec<_>>(), [&server]);
        })
    });
}
#[ip_test(I)]
#[test_case::test_matrix(
    [MarkDomain::Mark1, MarkDomain::Mark2],
    [None, Some(0), Some(1)]
)]
fn tcp_socket_marks<I: TcpTestIpExt>(domain: MarkDomain, mark: Option<u32>)
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>:
        TcpContext<I, TcpBindingsCtx<FakeDeviceId>>,
{
    // A fresh socket has no mark in `domain`. A mark set in that domain is
    // read back unchanged.
    let mut ctx = TcpCtx::with_core_ctx(TcpCoreCtx::new::<I>(
        I::TEST_ADDRS.local_ip,
        I::TEST_ADDRS.remote_ip,
        I::TEST_ADDRS.subnet.prefix(),
    ));
    let mut api = ctx.tcp_api::<I>();
    let socket = api.create(Default::default());
    assert_eq!(api.get_mark(&socket, domain), Mark(None));

    let expected = Mark(mark);
    api.set_mark(&socket, domain, expected);
    assert_eq!(api.get_mark(&socket, domain), expected);
}
#[ip_test(I)]
fn tcp_marks_for_accepted_sockets<I: TcpTestIpExt>()
where
    TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
        I,
        TcpBindingsCtx<FakeDeviceId>,
        SingleStackConverter = I::SingleStackConverter,
        DualStackConverter = I::DualStackConverter,
    >,
{
    // A socket accepted from a marked listener carries the listener's mark.
    let mut net = new_test_net::<I>();
    let backlog = NonZeroUsize::new(1).unwrap();
    let server_port = NonZeroU16::new(1234).unwrap();
    let listener_mark = Mark(Some(1));

    let server = net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let server = api.create(Default::default());
        api.set_mark(&server, MarkDomain::Mark1, listener_mark);
        api.bind(&server, None, Some(server_port)).expect("failed to bind the server socket");
        api.listen(&server, backlog).expect("can listen");
        server
    });

    let client_ends = WriteBackClientBuffers::default();
    let _client = net.with_context(LOCAL, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let socket = api.create(ProvidedBuffers::Buffers(client_ends.clone()));
        api.connect(&socket, Some(ZonedAddr::Unzoned(I::TEST_ADDRS.remote_ip)), server_port)
            .expect("failed to connect");
        socket
    });
    net.run_until_idle();

    net.with_context(REMOTE, |ctx| {
        let mut api = ctx.tcp_api::<I>();
        let (accepted, _addr, _accepted_ends) = api.accept(&server).expect("failed to accept");
        assert_eq!(api.get_mark(&accepted, MarkDomain::Mark1), listener_mark);
    });
}
#[ip_test(I)]
fn do_send_can_remove_sockets_from_demux_state<I: TcpTestIpExt>()
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
I,
TcpBindingsCtx<FakeDeviceId>,
SingleStackConverter = I::SingleStackConverter,
DualStackConverter = I::DualStackConverter,
>,
{
let (mut net, client, _client_snd_end, accepted) = bind_listen_connect_accept_inner(
I::UNSPECIFIED_ADDRESS,
BindConfig {
client_port: None,
server_port: PORT_1,
client_reuse_addr: false,
send_test_data: false,
},
0,
0.0,
);
net.with_context(LOCAL, |ctx| {
let mut api = ctx.tcp_api::<I>();
assert_eq!(api.shutdown(&client, ShutdownType::Send), Ok(true));
});
assert!(!net.step().is_idle());
assert!(!net.step().is_idle());
net.with_context(REMOTE, |ctx| {
let mut api = ctx.tcp_api::<I>();
assert_eq!(api.shutdown(&accepted, ShutdownType::Send), Ok(true));
});
assert!(!net.step().is_idle());
assert!(!net.step().is_idle());
net.with_context(LOCAL, |CtxPair { core_ctx, bindings_ctx: _ }| {
TcpDemuxContext::<I, _, _>::with_demux(core_ctx, |DemuxState { socketmap }| {
assert_eq!(socketmap.len(), 1);
})
});
assert_matches!(
&client.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(
conn,
&I::converter()
);
assert_matches!(
conn,
Connection {
state: State::TimeWait(_),
..
}
);
}
);
net.with_context(LOCAL, |ctx| {
ctx.with_fake_timer_ctx_mut(|ctx| {
ctx.instant.time =
ctx.instant.time.checked_add(Duration::from_secs(120 * 60)).unwrap()
});
let mut api = ctx.tcp_api::<I>();
api.do_send(&client);
});
assert_matches!(
&client.get().deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
let (conn, _addr) = assert_this_stack_conn::<I, _, TcpCoreCtx<_, _>>(
conn,
&I::converter()
);
assert_matches!(
conn,
Connection {
state: State::Closed(_),
..
}
);
}
);
net.with_context(LOCAL, |CtxPair { core_ctx, bindings_ctx: _ }| {
TcpDemuxContext::<I, _, _>::with_demux(core_ctx, |DemuxState { socketmap }| {
assert_eq!(socketmap.len(), 0);
})
});
}
#[ip_test(I)]
#[test_case(true; "server read over mss")]
#[test_case(false; "server read under mss")]
fn tcp_data_dequeue_sends_window_update<I: TcpTestIpExt>(server_read_over_mss: bool)
where
TcpCoreCtx<FakeDeviceId, TcpBindingsCtx<FakeDeviceId>>: TcpContext<
I,
TcpBindingsCtx<FakeDeviceId>,
SingleStackConverter = I::SingleStackConverter,
DualStackConverter = I::DualStackConverter,
>,
{
const EXTRA_DATA_AMOUNT: usize = 128;
set_logger_for_test();
let (mut net, client, client_snd_end, accepted) = bind_listen_connect_accept_inner(
I::UNSPECIFIED_ADDRESS,
BindConfig {
client_port: None,
server_port: PORT_1,
client_reuse_addr: false,
send_test_data: false,
},
0,
0.0,
);
let accepted_rcv_bufsize = net
.with_context(REMOTE, |ctx| ctx.tcp_api::<I>().receive_buffer_size(&accepted).unwrap());
client_snd_end.lock().extend(core::iter::repeat(0xAB).take(accepted_rcv_bufsize));
net.with_context(LOCAL, |ctx| {
ctx.tcp_api().do_send(&client);
});
net.run_until_idle();
client_snd_end.lock().extend(core::iter::repeat(0xAB).take(EXTRA_DATA_AMOUNT));
net.with_context(LOCAL, |ctx| {
ctx.tcp_api().do_send(&client);
});
let _ = net.step_deliver_frames();
let send_buf_len = net
.with_context(LOCAL, |ctx| {
ctx.tcp_api::<I>().with_send_buffer(&client, |buf| {
let BufferLimits { len, capacity: _ } = buf.limits();
len
})
})
.unwrap();
assert_eq!(send_buf_len, EXTRA_DATA_AMOUNT);
if server_read_over_mss {
let nread = net
.with_context(REMOTE, |ctx| {
ctx.tcp_api::<I>().with_receive_buffer(&accepted, |buf| {
buf.lock()
.read_with(|readable| readable.into_iter().map(|buf| buf.len()).sum())
})
})
.unwrap();
assert_eq!(nread, accepted_rcv_bufsize);
net.with_context(REMOTE, |ctx| ctx.tcp_api::<I>().on_receive_buffer_read(&accepted));
let (server_snd_max, server_acknum) = {
let socket = accepted.get();
let state = assert_matches!(
&socket.deref().socket_state,
TcpSocketStateInner::Bound(BoundSocketState::Connected { conn, .. }) => {
assert_matches!(I::get_state(conn), State::Established(e) => e)
}
);
(state.snd.max, state.rcv.nxt())
};
assert_eq!(
net.step_deliver_frames_with(|_, meta, frame| {
let mut buffer = Buf::new(frame.clone(), ..);
let (packet_seq, packet_ack, window_size, body_len) = match I::VERSION {
IpVersion::V4 => {
let meta =
assert_matches!(&meta, DualStackSendIpPacketMeta::V4(v4) => v4);
assert_eq!(*meta.src_ip, Ipv4::TEST_ADDRS.remote_ip.into_addr());
assert_eq!(*meta.dst_ip, Ipv4::TEST_ADDRS.local_ip.into_addr());
let parsed = buffer
.parse_with::<_, TcpSegment<_>>(TcpParseArgs::new(
*meta.src_ip,
*meta.dst_ip,
))
.expect("failed to parse");
(
parsed.seq_num(),
parsed.ack_num().unwrap(),
parsed.window_size(),
parsed.body().len(),
)
}
IpVersion::V6 => {
let meta =
assert_matches!(&meta, DualStackSendIpPacketMeta::V6(v6) => v6);
assert_eq!(*meta.src_ip, Ipv6::TEST_ADDRS.remote_ip.into_addr());
assert_eq!(*meta.dst_ip, Ipv6::TEST_ADDRS.local_ip.into_addr());
let parsed = buffer
.parse_with::<_, TcpSegment<_>>(TcpParseArgs::new(
*meta.src_ip,
*meta.dst_ip,
))
.expect("failed to parse");
(
parsed.seq_num(),
parsed.ack_num().unwrap(),
parsed.window_size(),
parsed.body().len(),
)
}
};
assert_eq!(packet_seq, u32::from(server_snd_max));
assert_eq!(packet_ack, u32::from(server_acknum));
assert_eq!(window_size, 65535);
assert_eq!(body_len, 0);
Some((meta, frame))
})
.frames_sent,
1
);
assert_eq!(
net.step_deliver_frames_with(|_, meta, frame| {
let mut buffer = Buf::new(frame.clone(), ..);
let body_len = match I::VERSION {
IpVersion::V4 => {
let meta =
assert_matches!(&meta, DualStackSendIpPacketMeta::V4(v4) => v4);
assert_eq!(*meta.src_ip, Ipv4::TEST_ADDRS.local_ip.into_addr());
assert_eq!(*meta.dst_ip, Ipv4::TEST_ADDRS.remote_ip.into_addr());
let parsed = buffer
.parse_with::<_, TcpSegment<_>>(TcpParseArgs::new(
*meta.src_ip,
*meta.dst_ip,
))
.expect("failed to parse");
parsed.body().len()
}
IpVersion::V6 => {
let meta =
assert_matches!(&meta, DualStackSendIpPacketMeta::V6(v6) => v6);
assert_eq!(*meta.src_ip, Ipv6::TEST_ADDRS.local_ip.into_addr());
assert_eq!(*meta.dst_ip, Ipv6::TEST_ADDRS.remote_ip.into_addr());
let parsed = buffer
.parse_with::<_, TcpSegment<_>>(TcpParseArgs::new(
*meta.src_ip,
*meta.dst_ip,
))
.expect("failed to parse");
parsed.body().len()
}
};
assert_eq!(body_len, EXTRA_DATA_AMOUNT);
Some((meta, frame))
})
.frames_sent,
1
);
assert_eq!(
net.step_deliver_frames_with(|_, meta, frame| {
let mut buffer = Buf::new(frame.clone(), ..);
let (packet_seq, packet_ack, body_len) = match I::VERSION {
IpVersion::V4 => {
let meta =
assert_matches!(&meta, DualStackSendIpPacketMeta::V4(v4) => v4);
assert_eq!(*meta.src_ip, Ipv4::TEST_ADDRS.remote_ip.into_addr());
assert_eq!(*meta.dst_ip, Ipv4::TEST_ADDRS.local_ip.into_addr());
let parsed = buffer
.parse_with::<_, TcpSegment<_>>(TcpParseArgs::new(
*meta.src_ip,
*meta.dst_ip,
))
.expect("failed to parse");
(parsed.seq_num(), parsed.ack_num().unwrap(), parsed.body().len())
}
IpVersion::V6 => {
let meta =
assert_matches!(&meta, DualStackSendIpPacketMeta::V6(v6) => v6);
assert_eq!(*meta.src_ip, Ipv6::TEST_ADDRS.remote_ip.into_addr());
assert_eq!(*meta.dst_ip, Ipv6::TEST_ADDRS.local_ip.into_addr());
let parsed = buffer
.parse_with::<_, TcpSegment<_>>(TcpParseArgs::new(
*meta.src_ip,
*meta.dst_ip,
))
.expect("failed to parse");
(parsed.seq_num(), parsed.ack_num().unwrap(), parsed.body().len())
}
};
assert_eq!(packet_seq, u32::from(server_snd_max));
assert_eq!(
packet_ack,
u32::from(server_acknum) + u32::try_from(EXTRA_DATA_AMOUNT).unwrap()
);
assert_eq!(body_len, 0);
Some((meta, frame))
})
.frames_sent,
1
);
let send_buf_len = net
.with_context(LOCAL, |ctx| {
ctx.tcp_api::<I>().with_send_buffer(&client, |buf| {
let BufferLimits { len, capacity: _ } = buf.limits();
len
})
})
.unwrap();
assert_eq!(send_buf_len, 0);
} else {
let nread = net
.with_context(REMOTE, |ctx| {
ctx.tcp_api::<I>()
.with_receive_buffer(&accepted, |buf| buf.lock().read_with(|_readable| 1))
})
.unwrap();
assert_eq!(nread, 1);
net.with_context(REMOTE, |ctx| ctx.tcp_api::<I>().on_receive_buffer_read(&accepted));
assert_eq!(net.step_deliver_frames().frames_sent, 0);
let send_buf_len = net
.with_context(LOCAL, |ctx| {
ctx.tcp_api::<I>().with_send_buffer(&client, |buf| {
let BufferLimits { len, capacity: _ } = buf.limits();
len
})
})
.unwrap();
assert_eq!(send_buf_len, EXTRA_DATA_AMOUNT);
}
}
}