1use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersion, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17 AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19use thiserror::Error;
20
21use crate::LocalAddressError;
22use crate::data_structures::socketmap::{
23 Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
24};
25use crate::device::{
26 DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
27};
28use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
29use crate::ip::BroadcastIpExt;
30use crate::socket::SocketCookie;
31use crate::socket::address::{
32 AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
33};
34use packet_formats::ip::{IpProto, Ipv4Proto, Ipv6Proto};
35
/// An [`Ip`] version paired with its counterpart version, so code can be
/// written generically over both halves of a dual-stack (IPv4 + IPv6) setup.
pub trait DualStackIpExt: Ip {
    /// The other IP version (`Ipv6` for `Ipv4`, `Ipv4` for `Ipv6`).
    ///
    /// The bound guarantees that the other version's other version is `Self`.
    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
}

impl DualStackIpExt for Ipv4 {
    type OtherVersion = Ipv6;
}

impl DualStackIpExt for Ipv6 {
    type OtherVersion = Ipv4;
}
50
/// A pair of values: one for IP version `I` and one for `I::OtherVersion`.
///
/// Each half is the [`GenericOverIp`] projection of `T` for its version.
pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
    /// The value for version `I`.
    this_stack: <T as GenericOverIp<I>>::Type,
    /// The value for version `I::OtherVersion`.
    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
    /// Ties the tuple to version `I`.
    _marker: IpVersionMarker<I>,
}
57
58impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
59 pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
61 where
62 T: GenericOverIp<I, Type = T>,
63 {
64 Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
65 }
66
67 pub fn into_inner(
69 self,
70 ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
71 let Self { this_stack, other_stack, _marker } = self;
72 (this_stack, other_stack)
73 }
74
75 pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
77 self.this_stack
78 }
79
80 pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
82 &self.this_stack
83 }
84
85 pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
87 self.other_stack
88 }
89
90 pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
92 &self.other_stack
93 }
94
95 pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
97 let Self { this_stack, other_stack, _marker } = self;
98 DualStackTuple {
99 this_stack: other_stack,
100 other_stack: this_stack,
101 _marker: IpVersionMarker::new(),
102 }
103 }
104
105 pub fn cast<X>(self) -> DualStackTuple<X, T>
114 where
115 X: DualStackIpExt,
116 T: GenericOverIp<X>
117 + GenericOverIp<X::OtherVersion>
118 + GenericOverIp<Ipv4>
119 + GenericOverIp<Ipv6>,
120 {
121 I::map_ip_in(
122 self,
123 |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
124 |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
125 )
126 }
127}
128
/// Allows a [`DualStackTuple`] to be re-parameterized over another IP
/// version `NewIp` (e.g. by `map_ip_in`/`map_ip_out`, as used in `cast`).
impl<
        I: DualStackIpExt,
        NewIp: DualStackIpExt,
        T: GenericOverIp<NewIp>
            + GenericOverIp<NewIp::OtherVersion>
            + GenericOverIp<I>
            + GenericOverIp<I::OtherVersion>,
    > GenericOverIp<NewIp> for DualStackTuple<I, T>
{
    type Type = DualStackTuple<NewIp, T>;
}
140
/// Socket-related extensions to [`Ip`].
pub trait SocketIpExt: Ip {
    /// `Self::LOOPBACK_ADDRESS`, wrapped as a [`SocketIpAddr`].
    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
        // SAFETY: relies on `Self::LOOPBACK_ADDRESS` satisfying `SocketIpAddr`'s
        // invariants. The `loopback_addr_is_valid_socket_addr` test checks this by
        // passing the address through the checked `SocketIpAddr::new`.
        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
    };
}

impl<I: Ip> SocketIpExt for I {}
152
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    /// Checks that the unchecked `LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR` constant
    /// is also accepted by the checked `SocketIpAddr` constructor.
    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        let checked = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr());
        let _addr = checked.expect("loopback address should be a valid SocketIpAddr");
    }
}
168
/// An IP protocol number that is either an IPv4 or an IPv6 protocol.
#[derive(Copy, Clone, Debug, PartialEq, Eq, GenericOverIp)]
#[generic_over_ip()]
pub enum EitherIpProto {
    /// An IPv4 protocol.
    V4(Ipv4Proto),
    /// An IPv6 protocol.
    V6(Ipv6Proto),
}
178
179impl EitherIpProto {
180 pub fn ip_version(&self) -> IpVersion {
182 match self {
183 Self::V4(_) => IpVersion::V4,
184 Self::V6(_) => IpVersion::V6,
185 }
186 }
187
188 pub fn ip_proto(&self) -> Option<IpProto> {
190 match self {
191 Self::V4(p) => match p {
192 Ipv4Proto::Proto(proto) => Some(*proto),
193 _ => None,
194 },
195 Self::V6(p) => match p {
196 Ipv6Proto::Proto(proto) => Some(*proto),
197 _ => None,
198 },
199 }
200 }
201
202 pub fn u8_value(&self) -> u8 {
204 match self {
205 Self::V4(p) => (*p).into(),
206 Self::V6(p) => (*p).into(),
207 }
208 }
209}
210
/// Summary information about a socket.
#[derive(Clone, Debug)]
#[cfg_attr(any(test, feature = "testutils"), derive(PartialEq, Eq))]
pub struct SocketInfo {
    /// The socket's IP protocol.
    pub proto: EitherIpProto,
    /// The socket's cookie.
    pub cookie: SocketCookie,
}
220
/// A value belonging either to "this" IP stack (`T`) or to the "other" one (`O`).
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// A value for this stack.
    ThisStack(T),
    /// A value for the other stack.
    OtherStack(O),
}

/// Hand-written so that, with the `instrumented` feature, `clone` is
/// `#[track_caller]`.
impl<T, O> Clone for EitherStack<T, O>
where
    T: Clone,
    O: Clone,
{
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        match self {
            EitherStack::ThisStack(this) => EitherStack::ThisStack(T::clone(this)),
            EitherStack::OtherStack(other) => EitherStack::OtherStack(O::clone(other)),
        }
    }
}
249
/// Either a dual-stack value (`DS`) or a not-dual-stack value (`NDS`).
#[derive(Debug)]
#[allow(missing_docs)]
pub enum MaybeDualStack<DS, NDS> {
    DualStack(DS),
    NotDualStack(NDS),
}

/// Re-parameterizes both variants' payloads over IP version `I`.
impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
    for MaybeDualStack<DS, NDS>
{
    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
}
281
/// Errors that can occur when enabling or disabling dual-stack on a socket.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
    /// The socket is bound; the setting may only change while unbound.
    #[error("a socket can only have dual stack enabled or disabled while unbound")]
    SocketIsBound,
    /// The socket's protocol does not support dual-stack operation.
    #[error(transparent)]
    NotCapable(#[from] NotDualStackCapableError),
}

/// The socket's protocol is not dual-stack capable.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
#[error("socket's protocol is not dual-stack capable")]
pub struct NotDualStackCapableError;
300
/// Which directions of a socket have been shut down.
///
/// The derived `Default` has both flags `false`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// The send direction has been shut down.
    pub send: bool,
    /// The receive direction has been shut down.
    pub receive: bool,
}
313
/// The direction(s) a shutdown request applies to.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
    /// Shut down sending only.
    Send,
    /// Shut down receiving only.
    Receive,
    /// Shut down both sending and receiving.
    SendAndReceive,
}
325
326impl ShutdownType {
327 pub fn to_send_receive(&self) -> (bool, bool) {
329 match self {
330 Self::Send => (true, false),
331 Self::Receive => (false, true),
332 Self::SendAndReceive => (true, true),
333 }
334 }
335
336 pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
338 match (send, receive) {
339 (true, false) => Some(Self::Send),
340 (false, true) => Some(Self::Receive),
341 (true, true) => Some(Self::SendAndReceive),
342 (false, false) => None,
343 }
344 }
345}
346
347pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
349 fn must_have_zone(&self) -> bool
355 where
356 Self: Copy,
357 {
358 self.try_into_null_zoned().is_some()
359 }
360
361 fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
365 if self.get().is_loopback() {
366 return None;
367 }
368 AddrAndZone::new(self, ())
369 }
370}
371
372impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
373
/// Resolves a possibly-zoned address together with an optional device.
pub trait SocketZonedAddrExt<W, A, D> {
    /// Splits `self` into its address and a device.
    ///
    /// The device comes from the address's zone and/or the supplied `device`.
    ///
    /// # Errors
    ///
    /// Returns an error if the zone and `device` disagree, or if neither is
    /// given for an address that must have a zone.
    fn resolve_addr_with_device(
        self,
        device: Option<D::Weak>,
    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
    where
        D: StrongDeviceIdentifier;
}
390
391impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
392where
393 W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
394 A: IpAddress,
395{
396 fn resolve_addr_with_device(
397 self,
398 device: Option<D::Weak>,
399 ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
400 where
401 D: StrongDeviceIdentifier,
402 {
403 let (addr, zone) = self.into_addr_zone();
404 let device = match (zone, device) {
405 (Some(zone), Some(device)) => {
406 if device != zone {
407 return Err(ZonedAddressError::DeviceZoneMismatch);
408 }
409 Some(EitherDeviceId::Strong(zone))
410 }
411 (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
412 (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
413 (None, None) => {
414 if addr.as_ref().must_have_zone() {
415 return Err(ZonedAddressError::RequiredZoneNotProvided);
416 } else {
417 None
418 }
419 }
420 };
421 Ok((addr, device))
422 }
423}
424
/// The parts of a socket's state needed to decide whether its bound device
/// may change (see `check_update`).
pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
    /// The socket's local IP address, if any.
    pub local_ip: Option<&'a SpecifiedAddr<A>>,
    /// The socket's remote IP address, if any.
    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
    /// The device the socket is currently bound to, if any.
    pub old_device: Option<&'a D>,
}
438
439impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
440 pub fn check_update<N>(
443 self,
444 new_device: Option<&N>,
445 ) -> Result<(), SocketDeviceUpdateNotAllowedError>
446 where
447 D: PartialEq<N>,
448 {
449 let Self { local_ip, remote_ip, old_device } = self;
450 let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
451 || remote_ip.is_some_and(|a| a.must_have_zone());
452
453 if !must_have_zone {
454 return Ok(());
455 }
456
457 let old_device = old_device.unwrap_or_else(|| {
458 panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
459 });
460
461 if new_device.is_some_and(|new_device| old_device == new_device) {
462 Ok(())
463 } else {
464 Err(SocketDeviceUpdateNotAllowedError)
465 }
466 }
467}
468
/// Error returned by [`SocketDeviceUpdate::check_update`] when the socket's
/// bound device may not be changed to the requested one.
///
/// Derives the same common traits as the module's other unit error types so
/// callers can `unwrap`/`expect` results and compare errors in tests.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SocketDeviceUpdateNotAllowedError;
471
/// The identifier types used in socket-map addresses.
pub trait SocketMapAddrSpec {
    /// The local identifier (convertible to a `NonZeroU16`).
    type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
    /// The remote identifier.
    type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}
482
/// Facts about a listener address that are used to compute its tag.
pub struct ListenerAddrInfo {
    /// Whether the listener is bound to a device.
    pub has_device: bool,
    /// Whether the listener is bound to a specific local IP address.
    pub specified_addr: bool,
}
491
492impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
493 pub(crate) fn info(&self) -> ListenerAddrInfo {
494 let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
495 ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
496 }
497}
498
/// Describes the per-address state stored in a socket map and how that state
/// is tagged.
pub trait SocketMapStateSpec {
    /// Tag computed for each entry (see `Tagged for Bound`).
    type AddrVecTag: Eq + Copy + Debug + 'static;

    /// Computes the tag for listener state at an address described by `info`.
    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;

    /// Computes the tag for connected state; `has_device` is whether the
    /// connected address includes a device.
    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;

    /// Identifier of a listening socket.
    type ListenerId: Clone + Debug;
    /// Identifier of a connected socket.
    type ConnId: Clone + Debug;

    /// Sharing options for listeners.
    type ListenerSharingState: Clone + Debug;

    /// Sharing options for connected sockets.
    type ConnSharingState: Clone + Debug;

    /// Per-address state for listeners.
    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
        + Debug;

    /// Per-address state for connected sockets.
    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
        + Debug;
}
534
/// A new socket's sharing state is incompatible with existing state at the
/// same address.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
539
/// A one-shot sink that accepts a single item.
pub trait Inserter<T> {
    /// Inserts `item`, consuming the inserter.
    fn insert(self, item: T);
}

/// Any mutable reference to an [`Extend`] collection is an inserter: the
/// item is appended with `extend`.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        Extend::<T>::extend(self, [item])
    }
}

/// `Never` is uninhabited, so this inserter can never be constructed or called.
impl<T> Inserter<T> for Never {
    fn insert(self, _: T) {
        match self {}
    }
}
560
/// The state stored at one address in a socket map, holding one or more ids.
pub trait SocketMapAddrStateSpec {
    /// The socket identifier type stored in this state.
    type Id;

    /// The sharing options that decide whether sockets may share an address.
    type SharingState;

    /// Handle returned by `try_get_inserter` for adding one more id.
    type Inserter<'a>: Inserter<Self::Id> + 'a
    where
        Self: 'a,
        Self::Id: 'a;

    /// Creates state holding the single socket `id` with `new_sharing_state`.
    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;

    /// Returns whether `id` is present in this state.
    fn contains_id(&self, id: &Self::Id) -> bool;

    /// Returns an inserter for adding a socket with `new_sharing_state` to
    /// this state.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if `new_sharing_state` is incompatible
    /// with this state.
    fn try_get_inserter<'a, 'b>(
        &'b mut self,
        new_sharing_state: &'a Self::SharingState,
    ) -> Result<Self::Inserter<'b>, IncompatibleError>;

    /// Like `try_get_inserter`, but only checks compatibility without
    /// producing an inserter.
    fn could_insert(&self, new_sharing_state: &Self::SharingState)
        -> Result<(), IncompatibleError>;

    /// Removes `id`; [`RemoveResult::IsLast`] tells the caller the state is
    /// now empty and should be dropped from the map.
    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
}
611
/// Address state that supports changing an existing socket's sharing options.
pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
    /// Changes the sharing state of `id` to `new_sharing_state`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if the new sharing state is not allowed.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError>;
}
622
/// Policy deciding whether a socket may be inserted at `Addr` given the rest of
/// the socket map.
pub trait SocketMapConflictPolicy<
    Addr,
    SharingState,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
>: SocketMapStateSpec
{
    /// Checks whether inserting a socket with `new_sharing_state` at `addr`
    /// conflicts with entries already in `socketmap`.
    ///
    /// # Errors
    ///
    /// Returns an [`InsertError`] describing the conflict.
    fn check_insert_conflicts(
        new_sharing_state: &SharingState,
        addr: &Addr,
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
    ) -> Result<(), InsertError>;
}
646
/// Policy deciding whether an existing socket at `Addr` may change its sharing
/// state.
pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
    A: SocketMapAddrSpec,
{
    /// Checks whether the socket at `addr` may go from `old_sharing` to
    /// `new_sharing` given the rest of `socketmap`.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateSharingError`] if the change is not allowed.
    fn allows_sharing_update(
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
        addr: &Addr,
        old_sharing: &SharingState,
        new_sharing: &SharingState,
    ) -> Result<(), UpdateSharingError>;
}
663
/// The state stored at one socket-map address: listener state for listener
/// addresses, connection state for connection addresses.
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
#[allow(missing_docs)]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
    Listen(S::ListenerAddrState),
    Conn(S::ConnAddrState),
}
672
/// A socket-map key: either a listener address or a connection address.
///
/// Trait derivations are bounded only on `D`, so `I` and `A` need not
/// implement them.
#[derive(Derivative)]
#[derivative(
    Debug(bound = "D: Debug"),
    Clone(bound = "D: Clone"),
    Eq(bound = "D: Eq"),
    PartialEq(bound = "D: PartialEq"),
    Hash(bound = "D: Hash")
)]
#[allow(missing_docs)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
700
701impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
702 Tagged<AddrVec<I, D, A>> for Bound<S>
703{
704 type Tag = S::AddrVecTag;
705 fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
706 match (self, address) {
707 (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
708 (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
709 S::connected_tag(device.is_some(), c)
710 }
711 (Bound::Listen(_), AddrVec::Conn(_)) => {
712 unreachable!("found listen state for conn addr")
713 }
714 (Bound::Conn(_), AddrVec::Listen(_)) => {
715 unreachable!("found conn state for listen addr")
716 }
717 }
718 }
719}
720
721impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
722 type IterShadows = AddrVecIter<I, D, A>;
723
724 fn iter_shadows(&self) -> Self::IterShadows {
725 let (socket_ip_addr, device) = match self.clone() {
726 AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
727 AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
728 };
729 let mut iter = match device {
730 Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
731 None => AddrVecIter::without_device(socket_ip_addr),
732 };
733 assert_eq!(iter.next().as_ref(), Some(self));
735 iter
736 }
737}
738
/// Coarse classification of a socket address: wildcard listener, listener on
/// a specific address, or connected.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[allow(missing_docs)]
pub enum SocketAddrType {
    AnyListener,
    SpecificListener,
    Connected,
}
747
748impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
749 fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
750 match addr {
751 Some(_) => SocketAddrType::SpecificListener,
752 None => SocketAddrType::AnyListener,
753 }
754 }
755}
756
757impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
758 fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
759 SocketAddrType::Connected
760 }
761}
762
/// Outcome of removing an id from per-address state.
pub enum RemoveResult {
    /// The id was removed and the state still holds other ids.
    Success,
    /// The removed id was the last one; the caller removes the address entry
    /// from the map (as `SocketStateEntry::remove` does).
    IsLast,
}
771
/// Identifier of a socket in the map: either a listener or a connection.
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub enum SocketId<S: SocketMapStateSpec> {
    /// A listening socket.
    Listener(S::ListenerId),
    /// A connected socket.
    Connection(S::ConnId),
}
778
/// Map from bound socket addresses (listener or connection) to their state.
///
/// `Default` is derived without bounds and produces an empty map.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The underlying address-to-state map.
    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
797
798impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
799 BoundSocketMap<I, D, A, S>
800{
801 pub fn len(&self) -> usize {
803 self.addr_to_state.len()
804 }
805}
806
/// Type-level marker selecting the listener half of a socket map.
pub enum Listener {}
/// Type-level marker selecting the connection half of a socket map.
pub enum Connection {}

/// A view over a socket map (shared or mutable reference) restricted to one
/// socket kind (`Listener` or `Connection`).
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
814
815impl<
816 'a,
817 I: Ip,
818 D: DeviceIdentifier,
819 SocketType: ConvertSocketMapState<I, D, A, S>,
820 A: SocketMapAddrSpec,
821 S: SocketMapStateSpec,
822> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
823where
824 S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
825{
826 pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
828 let Self(addr_to_state, _marker) = self;
829 addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
830 SocketType::from_bound_ref(state)
831 .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
832 })
833 }
834
835 pub fn could_insert(
841 self,
842 addr: &SocketType::Addr,
843 sharing: &SocketType::SharingState,
844 ) -> Result<(), InsertError> {
845 let Self(addr_to_state, _) = self;
846 match self.get_by_addr(addr) {
847 Some(state) => {
848 state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
849 }
850 None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
851 }
852 }
853}
854
/// A handle to one socket's entry in the socket map: the socket's id plus the
/// occupied map entry for its address.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct SocketStateEntry<
    'a,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
    S: SocketMapStateSpec,
    SocketType,
> {
    /// The socket's id.
    id: SocketId<S>,
    /// The occupied map entry holding the socket's address state.
    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
    /// Records the socket kind (`Listener` or `Connection`).
    _marker: PhantomData<SocketType>,
}
870
871impl<
872 'a,
873 I: Ip,
874 D: DeviceIdentifier,
875 SocketType: ConvertSocketMapState<I, D, A, S>,
876 A: SocketMapAddrSpec,
877 S: SocketMapStateSpec
878 + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
879> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
880where
881 SocketType::SharingState: Clone,
882 SocketType::Id: Clone,
883{
884 pub fn try_insert(
887 self,
888 socket_addr: SocketType::Addr,
889 tag_state: SocketType::SharingState,
890 id: SocketType::Id,
891 ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, InsertError> {
892 self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
893 .map(|(entry, ())| entry)
894 }
895
896 pub fn try_insert_with<R>(
901 self,
902 socket_addr: SocketType::Addr,
903 tag_state: SocketType::SharingState,
904 make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
905 ) -> Result<(SocketStateEntry<'a, I, D, A, S, SocketType>, R), InsertError> {
906 let Self(addr_to_state, _) = self;
907 S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state)?;
908
909 let addr = SocketType::to_addr_vec(&socket_addr);
910
911 match addr_to_state.entry(addr) {
912 Entry::Occupied(mut o) => {
913 let (id, ret) = o.map_mut(|bound| {
914 let bound = match SocketType::from_bound_mut(bound) {
915 Some(bound) => bound,
916 None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
917 };
918 match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
919 bound, &tag_state,
920 ) {
921 Ok(v) => {
922 let (id, ret) = make_id(socket_addr, tag_state);
923 v.insert(id.clone());
924 Ok((SocketType::to_socket_id(id), ret))
925 }
926 Err(IncompatibleError) => Err(InsertError::Exists),
927 }
928 })?;
929 Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
930 }
931 Entry::Vacant(v) => {
932 let (id, ret) = make_id(socket_addr, tag_state.clone());
933 let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
934 &tag_state,
935 id.clone(),
936 )));
937 let id = SocketType::to_socket_id(id);
938 Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
939 }
940 }
941 }
942
943 pub fn entry(
945 self,
946 id: &SocketType::Id,
947 addr: &SocketType::Addr,
948 ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
949 let Self(addr_to_state, _) = self;
950 let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
951 Entry::Vacant(_) => return None,
952 Entry::Occupied(o) => o,
953 };
954 let state = SocketType::from_bound_ref(addr_entry.get())?;
955
956 state.contains_id(id).then_some(SocketStateEntry {
957 id: SocketType::to_socket_id(id.clone()),
958 addr_entry,
959 _marker: PhantomData::default(),
960 })
961 }
962
963 pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
965 self.entry(id, addr)
966 .map(|entry| {
967 entry.remove();
968 })
969 .ok_or(NotFoundError)
970 }
971}
972
/// A socket's sharing state could not be changed (the policy rejected it, or
/// the address state reported an incompatibility).
#[derive(Debug)]
pub struct UpdateSharingError;
977
978impl<
979 'a,
980 I: Ip,
981 D: DeviceIdentifier,
982 SocketType: ConvertSocketMapState<I, D, A, S>,
983 A: SocketMapAddrSpec,
984 S: SocketMapStateSpec,
985> SocketStateEntry<'a, I, D, A, S, SocketType>
986where
987 SocketType::Id: Clone,
988{
989 pub fn get_addr(&self) -> &SocketType::Addr {
991 let Self { id: _, addr_entry, _marker } = self;
992 SocketType::from_addr_vec_ref(addr_entry.key())
993 }
994
995 pub fn id(&self) -> &SocketType::Id {
997 let Self { id, addr_entry: _, _marker } = self;
998 SocketType::from_socket_id_ref(id)
999 }
1000
1001 pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
1003 let Self { id, addr_entry, _marker } = self;
1004
1005 let new_addrvec = SocketType::to_addr_vec(&new_addr);
1006 let old_addr = addr_entry.key().clone();
1007 let (addr_state, addr_to_state) = addr_entry.remove_from_map();
1008 let addr_to_state = match addr_to_state.entry(new_addrvec) {
1009 Entry::Occupied(o) => o.into_map(),
1010 Entry::Vacant(v) => {
1011 if v.descendant_counts().len() != 0 {
1012 v.into_map()
1013 } else {
1014 let new_addr_entry = v.insert(addr_state);
1015 return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
1016 }
1017 }
1018 };
1019 let to_restore = addr_state;
1020 let addr_entry = match addr_to_state.entry(old_addr) {
1022 Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
1023 Entry::Vacant(v) => v.insert(to_restore),
1024 };
1025 return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
1026 }
1027
1028 pub fn remove(self) {
1030 let Self { id, mut addr_entry, _marker } = self;
1031 let addr = addr_entry.key().clone();
1032 match addr_entry.map_mut(|value| {
1033 let value = match SocketType::from_bound_mut(value) {
1034 Some(value) => value,
1035 None => unreachable!("found {:?} for address {:?}", value, addr),
1036 };
1037 value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
1038 }) {
1039 RemoveResult::Success => (),
1040 RemoveResult::IsLast => {
1041 let _: Bound<S> = addr_entry.remove();
1042 }
1043 }
1044 }
1045
1046 pub fn try_update_sharing(
1048 &mut self,
1049 old_sharing_state: &SocketType::SharingState,
1050 new_sharing_state: SocketType::SharingState,
1051 ) -> Result<(), UpdateSharingError>
1052 where
1053 SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1054 S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1055 {
1056 let Self { id, addr_entry, _marker } = self;
1057 let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1058
1059 S::allows_sharing_update(
1060 addr_entry.get_map(),
1061 addr,
1062 old_sharing_state,
1063 &new_sharing_state,
1064 )?;
1065
1066 addr_entry
1067 .map_mut(|value| {
1068 let value = match SocketType::from_bound_mut(value) {
1069 Some(value) => value,
1070 None => unreachable!("found invalid state {:?}", value),
1074 };
1075
1076 value.try_update_sharing(
1077 SocketType::from_socket_id_ref(id).clone(),
1078 &new_sharing_state,
1079 )
1080 })
1081 .map_err(|IncompatibleError| UpdateSharingError)
1082 }
1083}
1084
1085impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1086where
1087 AddrVec<I, D, A>: IterShadows,
1088 S: SocketMapStateSpec,
1089{
1090 pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1092 where
1093 S: SocketMapConflictPolicy<
1094 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1095 <S as SocketMapStateSpec>::ListenerSharingState,
1096 I,
1097 D,
1098 A,
1099 >,
1100 S::ListenerAddrState:
1101 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1102 {
1103 let Self { addr_to_state } = self;
1104 Sockets(addr_to_state, Default::default())
1105 }
1106
1107 pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1109 where
1110 S: SocketMapConflictPolicy<
1111 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1112 <S as SocketMapStateSpec>::ListenerSharingState,
1113 I,
1114 D,
1115 A,
1116 >,
1117 S::ListenerAddrState:
1118 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1119 {
1120 let Self { addr_to_state } = self;
1121 Sockets(addr_to_state, Default::default())
1122 }
1123
1124 pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1126 where
1127 S: SocketMapConflictPolicy<
1128 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1129 <S as SocketMapStateSpec>::ConnSharingState,
1130 I,
1131 D,
1132 A,
1133 >,
1134 S::ConnAddrState:
1135 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1136 {
1137 let Self { addr_to_state } = self;
1138 Sockets(addr_to_state, Default::default())
1139 }
1140
1141 pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1143 where
1144 S: SocketMapConflictPolicy<
1145 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1146 <S as SocketMapStateSpec>::ConnSharingState,
1147 I,
1148 D,
1149 A,
1150 >,
1151 S::ConnAddrState:
1152 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1153 {
1154 let Self { addr_to_state } = self;
1155 Sockets(addr_to_state, Default::default())
1156 }
1157
1158 #[cfg(test)]
1159 pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1160 let Self { addr_to_state } = self;
1161 addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1162 }
1163
1164 pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1166 let Self { addr_to_state } = self;
1167 addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1168 }
1169}
1170
/// Result of a receiver lookup: one matching entry, or an iterator over all
/// matching entries (used for broadcast/multicast).
pub enum FoundSockets<A, It> {
    /// A single matching entry.
    Single(A),
    /// An iterator over all matching entries.
    Multicast(It),
}

/// A state found in the socket map, paired with the address it was found at.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(
        &'a S::ConnAddrState,
        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
    ),
}
1190
1191impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1192where
1193 I: BroadcastIpExt<Addr: MulticastAddress>,
1194 D: DeviceIdentifier,
1195 A: SocketMapAddrSpec,
1196 S: SocketMapStateSpec
1197 + SocketMapConflictPolicy<
1198 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1199 <S as SocketMapStateSpec>::ListenerSharingState,
1200 I,
1201 D,
1202 A,
1203 > + SocketMapConflictPolicy<
1204 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1205 <S as SocketMapStateSpec>::ConnSharingState,
1206 I,
1207 D,
1208 A,
1209 >,
1210{
1211 pub fn lookup_connected(
1217 &self,
1218 (src_ip, src_port): (SocketIpAddr<I::Addr>, A::RemoteIdentifier),
1219 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1220 device: D,
1221 ) -> Option<&'_ S::ConnAddrState> {
1222 let mut addr = ConnAddr {
1223 ip: ConnIpAddr { local: (dst_ip, dst_port), remote: (src_ip, src_port) },
1224 device: Some(device),
1225 };
1226 let entry = self.conns().get_by_addr(&addr);
1227 if entry.is_some() {
1228 return entry;
1229 }
1230 addr.device = None;
1231 self.conns().get_by_addr(&addr)
1232 }
1233
1234 pub fn iter_receivers(
1240 &self,
1241 (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1242 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1243 device: D,
1244 broadcast: Option<I::BroadcastMarker>,
1245 ) -> Option<
1246 FoundSockets<
1247 AddrEntry<'_, I, D, A, S>,
1248 impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1249 >,
1250 > {
1251 let mut matching_entries = AddrVecIter::with_device(
1252 match (src_ip, src_port) {
1253 (Some(specified_src_ip), Some(src_port)) => {
1254 ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1255 .into()
1256 }
1257 _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1258 },
1259 device,
1260 )
1261 .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1262 AddrVec::Listen(l) => {
1263 self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1264 }
1265 AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1266 });
1267
1268 if broadcast.is_some() || dst_ip.addr().is_multicast() {
1269 Some(FoundSockets::Multicast(matching_entries))
1270 } else {
1271 let single_entry: Option<_> = matching_entries.next();
1272 single_entry.map(FoundSockets::Single)
1273 }
1274 }
1275}
1276
/// Reasons a socket could not be inserted into the socket map.
///
/// `Exists` is produced in this module when state at the exact address is
/// incompatible. The other variants are presumably produced by
/// `SocketMapConflictPolicy` implementations; confirm against those impls.
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// An address that would shadow the new one already exists.
    ShadowAddrExists,
    /// Incompatible state already exists at the exact address.
    Exists,
    /// The new address would shadow an existing one.
    WouldShadowExisting,
    /// Some other, indirect conflict prevents the insert.
    IndirectConflict,
}
1289
1290impl From<InsertError> for LocalAddressError {
1291 fn from(value: InsertError) -> Self {
1292 match value {
1293 InsertError::ShadowAddrExists
1294 | InsertError::Exists
1295 | InsertError::IndirectConflict
1296 | InsertError::WouldShadowExisting => LocalAddressError::AddressInUse,
1297 }
1298 }
1299}
1300
/// Conversions between a socket kind's own types (`Listener` or `Connection`)
/// and the shared socket-map types (`AddrVec`, `Bound`, `SocketId`).
pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The socket id type for this kind.
    type Id;
    /// The sharing state type for this kind.
    type SharingState;
    /// The address type for this kind.
    type Addr: Debug;
    /// The per-address state type for this kind.
    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;

    /// Wraps an address as a map key.
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
    /// Unwraps a map key; implementations panic on the other kind.
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
    /// Returns this kind's state, or `None` if `bound` is the other kind.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
    /// Mutable version of `from_bound_ref`.
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
    /// Wraps this kind's state as a `Bound`.
    fn to_bound(state: Self::AddrState) -> Bound<S>;
    /// Wraps this kind's id as a `SocketId`.
    fn to_socket_id(id: Self::Id) -> SocketId<S>;
    /// Unwraps a `SocketId`; implementations panic on the other kind.
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
1317
1318impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1319 ConvertSocketMapState<I, D, A, S> for Listener
1320{
1321 type Id = S::ListenerId;
1322 type SharingState = S::ListenerSharingState;
1323 type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1324 type AddrState = S::ListenerAddrState;
1325 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1326 AddrVec::Listen(addr.clone())
1327 }
1328
1329 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1330 match addr {
1331 AddrVec::Listen(l) => l,
1332 AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1333 }
1334 }
1335
1336 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1337 match bound {
1338 Bound::Listen(l) => Some(l),
1339 Bound::Conn(_c) => None,
1340 }
1341 }
1342
1343 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1344 match bound {
1345 Bound::Listen(l) => Some(l),
1346 Bound::Conn(_c) => None,
1347 }
1348 }
1349
1350 fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1351 Bound::Listen(state)
1352 }
1353 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1354 match id {
1355 SocketId::Listener(id) => id,
1356 SocketId::Connection(_) => unreachable!("connection ID for listener"),
1357 }
1358 }
1359 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1360 SocketId::Listener(id)
1361 }
1362}
1363
1364impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1365 ConvertSocketMapState<I, D, A, S> for Connection
1366{
1367 type Id = S::ConnId;
1368 type SharingState = S::ConnSharingState;
1369 type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1370 type AddrState = S::ConnAddrState;
1371 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1372 AddrVec::Conn(addr.clone())
1373 }
1374
1375 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1376 match addr {
1377 AddrVec::Conn(c) => c,
1378 AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1379 }
1380 }
1381
1382 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1383 match bound {
1384 Bound::Listen(_l) => None,
1385 Bound::Conn(c) => Some(c),
1386 }
1387 }
1388
1389 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1390 match bound {
1391 Bound::Listen(_l) => None,
1392 Bound::Conn(c) => Some(c),
1393 }
1394 }
1395
1396 fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1397 Bound::Conn(state)
1398 }
1399
1400 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1401 match id {
1402 SocketId::Connection(id) => id,
1403 SocketId::Listener(_) => unreachable!("listener ID for connection"),
1404 }
1405 }
1406 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1407 SocketId::Connection(id)
1408 }
1409}
1410
/// Identifier for a sharing domain.
///
/// Two `ReusePortOption::Enabled` values are shareable only when the domains
/// they carry are equal.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SharingDomain(u64);

impl SharingDomain {
    /// Creates a sharing domain from the raw identifier `id`.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }
}
1423
1424#[derive(Default, Debug, Eq, PartialEq, Clone, Copy, Hash)]
1427pub enum ReusePortOption {
1428 #[default]
1430 Disabled,
1431
1432 Enabled(SharingDomain),
1435}
1436
1437impl ReusePortOption {
1438 pub fn is_enabled(&self) -> bool {
1440 matches!(self, ReusePortOption::Enabled(_))
1441 }
1442
1443 pub fn is_shareable_with(&self, other: &Self) -> bool {
1446 match (self, other) {
1447 (ReusePortOption::Enabled(domain1), ReusePortOption::Enabled(domain2)) => {
1448 domain1 == domain2
1449 }
1450 _ => false,
1451 }
1452 }
1453}
1454
1455#[cfg(test)]
1456mod tests {
1457 use alloc::vec;
1458 use alloc::vec::Vec;
1459
1460 use assert_matches::assert_matches;
1461 use net_declare::{net_ip_v4, net_ip_v6};
1462 use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
1463 use netstack3_hashmap::HashSet;
1464 use test_case::test_case;
1465
1466 use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
1467 use crate::testutil::set_logger_for_test;
1468
1469 use super::*;
1470
1471 #[test_case(net_ip_v4!("8.8.8.8"))]
1472 #[test_case(net_ip_v4!("127.0.0.1"))]
1473 #[test_case(net_ip_v4!("127.0.8.9"))]
1474 #[test_case(net_ip_v4!("224.1.2.3"))]
1475 fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
1476 let addr = SpecifiedAddr::new(addr).unwrap();
1478 assert_eq!(addr.must_have_zone(), false);
1479 }
1480
1481 #[test_case(net_ip_v6!("1::2:3"), false)]
1482 #[test_case(net_ip_v6!("::1"), false; "localhost")]
1483 #[test_case(net_ip_v6!("1::"), false)]
1484 #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
1485 #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
1486 #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
1487 #[test_case(net_ip_v6!("fe80::1"), true)]
1488 fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
1489 let addr = SpecifiedAddr::new(addr).unwrap();
1492 assert_eq!(addr.must_have_zone(), must_have);
1493 }
1494
1495 #[test]
1496 fn try_into_null_zoned_ipv6() {
1497 assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
1498 let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
1499 const ZONE: u32 = 5;
1500 assert_eq!(
1501 zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
1502 Some(AddrAndZone::new(zoned, ZONE).unwrap())
1503 );
1504 }
1505
1506 enum FakeSpec {}
1507
1508 #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1509 struct Listener(usize);
1510
1511 #[derive(PartialEq, Eq, Debug, Copy, Clone)]
1512 struct SharingState {
1513 tag: char,
1514 shared: bool,
1515 }
1516
1517 impl SharingState {
1518 fn exclusive(tag: char) -> Self {
1519 Self { tag, shared: false }
1520 }
1521
1522 fn shared(tag: char) -> Self {
1523 Self { tag, shared: true }
1524 }
1525 }
1526
1527 impl SharingState {
1528 fn can_share_with(&self, other: &Self) -> bool {
1529 self.tag == other.tag && self.shared && other.shared
1530 }
1531 }
1532
1533 #[derive(PartialEq, Eq, Debug)]
1534 struct Multiple<T> {
1535 sharing_state: SharingState,
1536 entries: Vec<T>,
1537 }
1538
1539 impl<T> Multiple<T> {
1540 fn new_exclusive(tag: char, entries: Vec<T>) -> Self {
1541 Self { sharing_state: SharingState { tag, shared: false }, entries }
1542 }
1543 }
1544
1545 #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1546 struct Conn(usize);
1547
1548 enum FakeAddrSpec {}
1549
1550 impl SocketMapAddrSpec for FakeAddrSpec {
1551 type LocalIdentifier = NonZeroU16;
1552 type RemoteIdentifier = ();
1553 }
1554
1555 impl SocketMapStateSpec for FakeSpec {
1556 type AddrVecTag = SharingState;
1557
1558 type ListenerId = Listener;
1559 type ConnId = Conn;
1560
1561 type ListenerSharingState = SharingState;
1562 type ConnSharingState = SharingState;
1563
1564 type ListenerAddrState = Multiple<Listener>;
1565 type ConnAddrState = Multiple<Conn>;
1566
1567 fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
1568 state.sharing_state
1569 }
1570
1571 fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
1572 state.sharing_state
1573 }
1574 }
1575
    /// The socket map under test: IPv4, fake weak device IDs, and the fake
    /// address and state specs defined in this module.
    type FakeBoundSocketMap =
        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
1578
1579 #[derive(Default)]
1583 struct FakeSocketIdGen {
1584 next_id: usize,
1585 }
1586
1587 impl FakeSocketIdGen {
1588 fn next(&mut self) -> usize {
1589 let next_next_id = self.next_id + 1;
1590 core::mem::replace(&mut self.next_id, next_next_id)
1591 }
1592 }
1593
1594 impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
1595 type Id = I;
1596 type SharingState = SharingState;
1597 type Inserter<'a>
1598 = &'a mut Vec<I>
1599 where
1600 I: 'a;
1601
1602 fn new(sharing_state: &SharingState, id: I) -> Self {
1603 Self { sharing_state: *sharing_state, entries: vec![id] }
1604 }
1605
1606 fn contains_id(&self, id: &Self::Id) -> bool {
1607 self.entries.contains(id)
1608 }
1609
1610 fn try_get_inserter<'a, 'b>(
1611 &'b mut self,
1612 new_sharing_state: &'a SharingState,
1613 ) -> Result<Self::Inserter<'b>, IncompatibleError> {
1614 (self.sharing_state == *new_sharing_state)
1615 .then_some(&mut self.entries)
1616 .ok_or(IncompatibleError)
1617 }
1618
1619 fn could_insert(&self, new_sharing_state: &SharingState) -> Result<(), IncompatibleError> {
1620 (self.sharing_state == *new_sharing_state).then_some(()).ok_or(IncompatibleError)
1621 }
1622
1623 fn remove_by_id(&mut self, id: I) -> RemoveResult {
1624 let index = self.entries.iter().position(|i| i == &id).expect("did not find id");
1625 let _: I = self.entries.swap_remove(index);
1626 if self.entries.is_empty() { RemoveResult::IsLast } else { RemoveResult::Success }
1627 }
1628 }
1629
1630 impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1631 SocketMapConflictPolicy<A, SharingState, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1632 for FakeSpec
1633 {
1634 fn check_insert_conflicts(
1635 new_sharing_state: &SharingState,
1636 addr: &A,
1637 socketmap: &SocketMap<
1638 AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1639 Bound<FakeSpec>,
1640 >,
1641 ) -> Result<(), InsertError> {
1642 let dest: AddrVec<_, _, _> = addr.clone().into();
1643 if dest.iter_shadows().any(|a| {
1644 let entry = socketmap.get(&a);
1645 match entry {
1646 Some(Bound::Listen(Multiple { sharing_state, .. }))
1647 | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1648 !sharing_state.can_share_with(new_sharing_state)
1649 }
1650 None => false,
1651 }
1652 }) {
1653 return Err(InsertError::ShadowAddrExists);
1654 }
1655
1656 match socketmap.get(&dest) {
1657 Some(Bound::Listen(Multiple { sharing_state, .. }))
1658 | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1659 if sharing_state != new_sharing_state {
1662 return Err(InsertError::Exists);
1663 }
1664 }
1665 None => (),
1666 }
1667
1668 if socketmap
1669 .descendant_counts(&dest)
1670 .any(|(sharing_state, _count)| !sharing_state.can_share_with(new_sharing_state))
1671 {
1672 Err(InsertError::WouldShadowExisting)
1673 } else {
1674 Ok(())
1675 }
1676 }
1677 }
1678
1679 impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
1680 fn try_update_sharing(
1681 &mut self,
1682 id: Self::Id,
1683 new_sharing_state: &Self::SharingState,
1684 ) -> Result<(), IncompatibleError> {
1685 if self.sharing_state == *new_sharing_state {
1686 return Ok(());
1687 }
1688
1689 if self.entries.len() != 1 {
1694 return Err(IncompatibleError);
1695 }
1696 assert!(self.entries.contains(&id));
1697 self.sharing_state = *new_sharing_state;
1698 Ok(())
1699 }
1700 }
1701
1702 impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1703 SocketMapUpdateSharingPolicy<
1704 A,
1705 SharingState,
1706 Ipv4,
1707 FakeWeakDeviceId<FakeDeviceId>,
1708 FakeAddrSpec,
1709 > for FakeSpec
1710 {
1711 fn allows_sharing_update(
1712 _socketmap: &SocketMap<
1713 AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1714 Bound<Self>,
1715 >,
1716 _addr: &A,
1717 _old_sharing: &SharingState,
1718 _new_sharing_state: &SharingState,
1719 ) -> Result<(), UpdateSharingError> {
1720 Ok(())
1721 }
1722 }
1723
    /// Listener address used by many tests: IP `1.2.3.4`, identifier 1, no device.
    const LISTENER_ADDR: ListenerAddr<
        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
        FakeWeakDeviceId<FakeDeviceId>,
    > = ListenerAddr {
        ip: ListenerIpAddr {
            // SAFETY: `1.2.3.4` is a specified unicast IPv4 address.
            // NOTE(review): confirm this meets `SocketIpAddr::new_unchecked`'s
            // documented precondition (the checked `SocketIpAddr::new` is fallible).
            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
            identifier: NonZeroU16::new(1).unwrap(),
        },
        device: None,
    };
1734
1735 const CONN_ADDR: ConnAddr<
1736 ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
1737 FakeWeakDeviceId<FakeDeviceId>,
1738 > = ConnAddr {
1739 ip: ConnIpAddr {
1740 local: (
1741 unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
1742 NonZeroU16::new(1).unwrap(),
1743 ),
1744 remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
1745 },
1746 device: None,
1747 };
1748
1749 #[test]
1750 fn bound_insert_get_remove_listener() {
1751 set_logger_for_test();
1752 let mut bound = FakeBoundSocketMap::default();
1753 let mut fake_id_gen = FakeSocketIdGen::default();
1754
1755 let addr = LISTENER_ADDR;
1756
1757 let id = {
1758 let entry = bound
1759 .listeners_mut()
1760 .try_insert(addr, SharingState::exclusive('v'), Listener(fake_id_gen.next()))
1761 .unwrap();
1762 assert_eq!(entry.get_addr(), &addr);
1763 entry.id().clone()
1764 };
1765
1766 assert_eq!(
1767 bound.listeners().get_by_addr(&addr),
1768 Some(&Multiple::new_exclusive('v', vec![id]))
1769 );
1770
1771 assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
1772 assert_eq!(bound.listeners().get_by_addr(&addr), None);
1773 }
1774
1775 #[test]
1776 fn bound_insert_get_remove_conn() {
1777 set_logger_for_test();
1778 let mut bound = FakeBoundSocketMap::default();
1779 let mut fake_id_gen = FakeSocketIdGen::default();
1780
1781 let addr = CONN_ADDR;
1782
1783 let id = {
1784 let entry = bound
1785 .conns_mut()
1786 .try_insert(addr, SharingState::exclusive('v'), Conn(fake_id_gen.next()))
1787 .unwrap();
1788 assert_eq!(entry.get_addr(), &addr);
1789 entry.id().clone()
1790 };
1791
1792 assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('v', vec![id])));
1793
1794 assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
1795 assert_eq!(bound.conns().get_by_addr(&addr), None);
1796 }
1797
1798 #[test]
1799 fn bound_iter_addrs() {
1800 set_logger_for_test();
1801 let mut bound = FakeBoundSocketMap::default();
1802 let mut fake_id_gen = FakeSocketIdGen::default();
1803
1804 let listener_addrs = [
1805 (Some(net_ip_v4!("1.1.1.1")), 1),
1806 (Some(net_ip_v4!("2.2.2.2")), 2),
1807 (Some(net_ip_v4!("1.1.1.1")), 3),
1808 (None, 4),
1809 ]
1810 .map(|(ip, identifier)| ListenerAddr {
1811 device: None,
1812 ip: ListenerIpAddr {
1813 addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
1814 identifier: NonZeroU16::new(identifier).unwrap(),
1815 },
1816 });
1817 let conn_addrs = [
1818 (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
1819 (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
1820 ]
1821 .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
1822 ip: ConnIpAddr {
1823 local: (
1824 SocketIpAddr::new(local_ip).unwrap(),
1825 NonZeroU16::new(local_identifier).unwrap(),
1826 ),
1827 remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
1828 },
1829 device: None,
1830 });
1831
1832 for addr in listener_addrs.iter().cloned() {
1833 let _entry = bound
1834 .listeners_mut()
1835 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1836 .unwrap();
1837 }
1838 for addr in conn_addrs.iter().cloned() {
1839 let _entry = bound
1840 .conns_mut()
1841 .try_insert(addr, SharingState::exclusive('a'), Conn(fake_id_gen.next()))
1842 .unwrap();
1843 }
1844 let expected_addrs = listener_addrs
1845 .into_iter()
1846 .map(Into::into)
1847 .chain(conn_addrs.into_iter().map(Into::into))
1848 .collect::<HashSet<_>>();
1849
1850 assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
1851 }
1852
1853 #[test]
1854 fn try_insert_with_callback_not_called_on_error() {
1855 set_logger_for_test();
1858 let mut bound = FakeBoundSocketMap::default();
1859 let addr = LISTENER_ADDR;
1860
1861 let _: &Listener = bound
1863 .listeners_mut()
1864 .try_insert(addr, SharingState::exclusive('a'), Listener(0))
1865 .unwrap()
1866 .id();
1867
1868 fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
1872 panic!("should never be called");
1873 }
1874
1875 assert_matches!(
1876 bound.listeners_mut().try_insert_with(
1877 addr,
1878 SharingState::exclusive('b'),
1879 is_never_called
1880 ),
1881 Err(InsertError::Exists)
1882 );
1883 assert_matches!(
1884 bound.listeners_mut().try_insert_with(
1885 ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
1886 SharingState::exclusive('b'),
1887 is_never_called
1888 ),
1889 Err(InsertError::ShadowAddrExists)
1890 );
1891 assert_matches!(
1892 bound.conns_mut().try_insert_with(
1893 ConnAddr {
1894 device: None,
1895 ip: ConnIpAddr {
1896 local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1897 remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1898 },
1899 },
1900 SharingState::exclusive('b'),
1901 is_never_called,
1902 ),
1903 Err(InsertError::ShadowAddrExists)
1904 );
1905 }
1906
1907 #[test]
1908 fn insert_listener_conflict_with_listener() {
1909 set_logger_for_test();
1910 let mut bound = FakeBoundSocketMap::default();
1911 let mut fake_id_gen = FakeSocketIdGen::default();
1912 let addr = LISTENER_ADDR;
1913
1914 let _: &Listener = bound
1915 .listeners_mut()
1916 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1917 .unwrap()
1918 .id();
1919 assert_matches!(
1920 bound.listeners_mut().try_insert(
1921 addr,
1922 SharingState::exclusive('b'),
1923 Listener(fake_id_gen.next())
1924 ),
1925 Err(InsertError::Exists)
1926 );
1927 }
1928
1929 #[test]
1930 fn insert_listener_conflict_with_shadower() {
1931 set_logger_for_test();
1932 let mut bound = FakeBoundSocketMap::default();
1933 let mut fake_id_gen = FakeSocketIdGen::default();
1934 let addr = LISTENER_ADDR;
1935 let shadows_addr = {
1936 assert_eq!(addr.device, None);
1937 ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
1938 };
1939
1940 let _: &Listener = bound
1941 .listeners_mut()
1942 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1943 .unwrap()
1944 .id();
1945 assert_matches!(
1946 bound.listeners_mut().try_insert(
1947 shadows_addr,
1948 SharingState::exclusive('b'),
1949 Listener(fake_id_gen.next())
1950 ),
1951 Err(InsertError::ShadowAddrExists)
1952 );
1953 }
1954
1955 #[test]
1956 fn insert_conn_conflict_with_listener() {
1957 set_logger_for_test();
1958 let mut bound = FakeBoundSocketMap::default();
1959 let mut fake_id_gen = FakeSocketIdGen::default();
1960 let addr = LISTENER_ADDR;
1961 let shadows_addr = ConnAddr {
1962 device: None,
1963 ip: ConnIpAddr {
1964 local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1965 remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1966 },
1967 };
1968
1969 let _: &Listener = bound
1970 .listeners_mut()
1971 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1972 .unwrap()
1973 .id();
1974 assert_matches!(
1975 bound.conns_mut().try_insert(
1976 shadows_addr,
1977 SharingState::exclusive('b'),
1978 Conn(fake_id_gen.next())
1979 ),
1980 Err(InsertError::ShadowAddrExists)
1981 );
1982 }
1983
1984 #[test]
1985 fn insert_and_remove_listener() {
1986 set_logger_for_test();
1987 let mut bound = FakeBoundSocketMap::default();
1988 let mut fake_id_gen = FakeSocketIdGen::default();
1989 let addr = LISTENER_ADDR;
1990
1991 let a = bound
1992 .listeners_mut()
1993 .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1994 .unwrap()
1995 .id()
1996 .clone();
1997 let b = bound
1998 .listeners_mut()
1999 .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
2000 .unwrap()
2001 .id()
2002 .clone();
2003 assert_ne!(a, b);
2004
2005 assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
2006 assert_eq!(
2007 bound.listeners().get_by_addr(&addr),
2008 Some(&Multiple::new_exclusive('x', vec![b]))
2009 );
2010 }
2011
2012 #[test]
2013 fn insert_and_remove_conn() {
2014 set_logger_for_test();
2015 let mut bound = FakeBoundSocketMap::default();
2016 let mut fake_id_gen = FakeSocketIdGen::default();
2017 let addr = CONN_ADDR;
2018
2019 let a = bound
2020 .conns_mut()
2021 .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
2022 .unwrap()
2023 .id()
2024 .clone();
2025 let b = bound
2026 .conns_mut()
2027 .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
2028 .unwrap()
2029 .id()
2030 .clone();
2031 assert_ne!(a, b);
2032
2033 assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
2034 assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('x', vec![b])));
2035 }
2036
2037 #[test]
2038 fn update_listener_to_shadowed_addr_fails() {
2039 let mut bound = FakeBoundSocketMap::default();
2040 let mut fake_id_gen = FakeSocketIdGen::default();
2041
2042 let first_addr = LISTENER_ADDR;
2043 let second_addr = ListenerAddr {
2044 ip: ListenerIpAddr {
2045 addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
2046 ..LISTENER_ADDR.ip
2047 },
2048 ..LISTENER_ADDR
2049 };
2050 let both_shadow = ListenerAddr {
2051 ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
2052 device: None,
2053 };
2054
2055 let first = bound
2056 .listeners_mut()
2057 .try_insert(first_addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
2058 .unwrap()
2059 .id()
2060 .clone();
2061 let second = bound
2062 .listeners_mut()
2063 .try_insert(second_addr, SharingState::exclusive('b'), Listener(fake_id_gen.next()))
2064 .unwrap()
2065 .id()
2066 .clone();
2067
2068 let (ExistsError, entry) = bound
2071 .listeners_mut()
2072 .entry(&second, &second_addr)
2073 .unwrap()
2074 .try_update_addr(both_shadow)
2075 .expect_err("update should fail");
2076
2077 assert_eq!(entry.id(), &second);
2079 drop(entry);
2080
2081 let (ExistsError, entry) = bound
2082 .listeners_mut()
2083 .entry(&first, &first_addr)
2084 .unwrap()
2085 .try_update_addr(both_shadow)
2086 .expect_err("update should fail");
2087 assert_eq!(entry.get_addr(), &first_addr);
2088 }
2089
2090 #[test]
2091 fn nonexistent_conn_entry() {
2092 let mut map = FakeBoundSocketMap::default();
2093 let mut fake_id_gen = FakeSocketIdGen::default();
2094 let addr = CONN_ADDR;
2095 let conn_id = map
2096 .conns_mut()
2097 .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2098 .expect("failed to insert")
2099 .id()
2100 .clone();
2101 assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
2102
2103 assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
2104 }
2105
2106 #[test]
2107 fn update_conn_sharing() {
2108 let mut map = FakeBoundSocketMap::default();
2109 let mut fake_id_gen = FakeSocketIdGen::default();
2110 let addr = CONN_ADDR;
2111 let mut entry = map
2112 .conns_mut()
2113 .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2114 .expect("failed to insert");
2115
2116 entry
2117 .try_update_sharing(&SharingState::exclusive('a'), SharingState::exclusive('d'))
2118 .expect("worked");
2119 let mut second_conn = map
2122 .conns_mut()
2123 .try_insert(addr.clone(), SharingState::exclusive('d'), Conn(fake_id_gen.next()))
2124 .expect("can insert");
2125 assert_matches!(
2126 second_conn
2127 .try_update_sharing(&SharingState::exclusive('d'), SharingState::exclusive('e')),
2128 Err(UpdateSharingError)
2129 );
2130 }
2131
2132 #[test]
2133 fn lookup_connected() {
2134 let mut map = FakeBoundSocketMap::default();
2135 let mut fake_id_gen = FakeSocketIdGen::default();
2136
2137 let sharing_state = SharingState::shared('a');
2138
2139 let device_id = FakeWeakDeviceId(FakeDeviceId);
2140 let entry1 = map
2141 .conns_mut()
2142 .try_insert(CONN_ADDR, sharing_state, Conn(fake_id_gen.next()))
2143 .expect("failed to insert")
2144 .id()
2145 .clone();
2146 let conn = map
2147 .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2148 .expect("lookup should succeed");
2149 assert!(conn.contains_id(&entry1));
2150
2151 let addr_with_device = ConnAddr { device: Some(device_id), ..CONN_ADDR };
2154 let entry2 = map
2155 .conns_mut()
2156 .try_insert(addr_with_device, sharing_state, Conn(fake_id_gen.next()))
2157 .expect("failed to insert")
2158 .id()
2159 .clone();
2160 let conn = map
2161 .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2162 .expect("lookup should succeed");
2163 assert!(conn.contains_id(&entry2));
2164 }
2165}