1use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17 AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19use thiserror::Error;
20
21use crate::data_structures::socketmap::{
22 Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
23};
24use crate::device::{
25 DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
26};
27use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
28use crate::ip::BroadcastIpExt;
29use crate::socket::address::{
30 AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
31};
32
/// An [`Ip`] version that has a designated "other" IP version, so code can be
/// written generically over a pair of IP stacks.
pub trait DualStackIpExt: Ip {
    /// The IP version that is not `Self`.
    ///
    /// The `OtherVersion = Self` bound makes the relationship symmetric: the
    /// other version of the other version is `Self` again.
    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
}

impl DualStackIpExt for Ipv4 {
    type OtherVersion = Ipv6;
}

impl DualStackIpExt for Ipv6 {
    type OtherVersion = Ipv4;
}

/// A pair of values: one for IP version `I` ("this stack") and one for
/// `I::OtherVersion` ("the other stack").
///
/// `T` is projected onto each version through [`GenericOverIp`] to obtain the
/// concrete type stored for that version.
pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
    /// The value for IP version `I`.
    this_stack: <T as GenericOverIp<I>>::Type,
    /// The value for IP version `I::OtherVersion`.
    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
    _marker: IpVersionMarker<I>,
}
54
55impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
56 pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
58 where
59 T: GenericOverIp<I, Type = T>,
60 {
61 Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
62 }
63
64 pub fn into_inner(
66 self,
67 ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
68 let Self { this_stack, other_stack, _marker } = self;
69 (this_stack, other_stack)
70 }
71
72 pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
74 self.this_stack
75 }
76
77 pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
79 &self.this_stack
80 }
81
82 pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
84 self.other_stack
85 }
86
87 pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
89 &self.other_stack
90 }
91
92 pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
94 let Self { this_stack, other_stack, _marker } = self;
95 DualStackTuple {
96 this_stack: other_stack,
97 other_stack: this_stack,
98 _marker: IpVersionMarker::new(),
99 }
100 }
101
102 pub fn cast<X>(self) -> DualStackTuple<X, T>
111 where
112 X: DualStackIpExt,
113 T: GenericOverIp<X>
114 + GenericOverIp<X::OtherVersion>
115 + GenericOverIp<Ipv4>
116 + GenericOverIp<Ipv6>,
117 {
118 I::map_ip_in(
119 self,
120 |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
121 |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
122 )
123 }
124}
125
// A `DualStackTuple` can be re-expressed for any IP version `NewIp`, provided
// `T` can be projected onto both `NewIp` and `NewIp`'s other version.
impl<
        I: DualStackIpExt,
        NewIp: DualStackIpExt,
        T: GenericOverIp<NewIp>
            + GenericOverIp<NewIp::OtherVersion>
            + GenericOverIp<I>
            + GenericOverIp<I::OtherVersion>,
    > GenericOverIp<NewIp> for DualStackTuple<I, T>
{
    type Type = DualStackTuple<NewIp, T>;
}

/// Socket-related constants available for every [`Ip`] version.
pub trait SocketIpExt: Ip {
    /// This IP version's loopback address, as a [`SocketIpAddr`].
    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
        // SAFETY: `new_from_specified_unchecked` is `unsafe`, presumably because
        // it does not validate its argument, so the loopback address must
        // already satisfy `SocketIpAddr`'s invariants. The
        // `loopback_addr_is_valid_socket_addr` test below checks this by feeding
        // the address to the checked `SocketIpAddr::new`.
        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
    };
}

// Every IP version gets the socket constants through this blanket impl.
impl<I: Ip> SocketIpExt for I {}

#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    // Checks that the constant built with the unchecked constructor is also
    // accepted by the checked `SocketIpAddr::new`.
    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
            .expect("loopback address should be a valid SocketIpAddr");
    }
}
165
/// A value that belongs either to this IP stack (`T`) or to the other IP
/// stack (`O`).
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// A value for this stack.
    ThisStack(T),
    /// A value for the other stack.
    OtherStack(O),
}

// `Clone` is written by hand rather than derived, presumably so that it can
// carry `track_caller` when the `instrumented` feature is enabled.
impl<T, O> Clone for EitherStack<T, O>
where
    T: Clone,
    O: Clone,
{
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        match self {
            EitherStack::ThisStack(this) => EitherStack::ThisStack(T::clone(this)),
            EitherStack::OtherStack(other) => EitherStack::OtherStack(O::clone(other)),
        }
    }
}
194
/// Holds either a dual-stack value (`DS`) or a not-dual-stack value (`NDS`).
#[derive(Debug)]
#[allow(missing_docs)]
pub enum MaybeDualStack<DS, NDS> {
    DualStack(DS),
    NotDualStack(NDS),
}

// `MaybeDualStack` is generic over IP whenever both variant types are; each
// one is projected onto `I` independently.
impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
    for MaybeDualStack<DS, NDS>
{
    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
}

/// Error returned when enabling or disabling dual-stack operation on a socket.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
    /// The socket is bound; the setting can only change while it is unbound.
    #[error("a socket can only have dual stack enabled or disabled while unbound")]
    SocketIsBound,
    /// The socket's protocol does not support dual-stack operation.
    #[error(transparent)]
    NotCapable(#[from] NotDualStackCapableError),
}

/// Error indicating that a socket's protocol is not dual-stack capable.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
#[error("socket's protocol is not dual-stack capable")]
pub struct NotDualStackCapableError;

/// Records which directions of a socket have been shut down.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// Shutdown state of the send direction.
    pub send: bool,
    /// Shutdown state of the receive direction.
    pub receive: bool,
}

/// The direction(s) to shut down on a socket.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
    /// Shut down sending only.
    Send,
    /// Shut down receiving only.
    Receive,
    /// Shut down both sending and receiving.
    SendAndReceive,
}
270
271impl ShutdownType {
272 pub fn to_send_receive(&self) -> (bool, bool) {
274 match self {
275 Self::Send => (true, false),
276 Self::Receive => (false, true),
277 Self::SendAndReceive => (true, true),
278 }
279 }
280
281 pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
283 match (send, receive) {
284 (true, false) => Some(Self::Send),
285 (false, true) => Some(Self::Receive),
286 (true, true) => Some(Self::SendAndReceive),
287 (false, false) => None,
288 }
289 }
290}
291
292pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
294 fn must_have_zone(&self) -> bool
300 where
301 Self: Copy,
302 {
303 self.try_into_null_zoned().is_some()
304 }
305
306 fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
310 if self.get().is_loopback() {
311 return None;
312 }
313 AddrAndZone::new(self, ())
314 }
315}
316
317impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
318
/// Extension methods for socket addresses that may carry a zone.
pub trait SocketZonedAddrExt<W, A, D> {
    /// Returns the address together with the device the socket should use,
    /// reconciling the address's zone with `device`.
    ///
    /// Returns a [`ZonedAddressError`] when the two cannot be reconciled.
    fn resolve_addr_with_device(
        self,
        device: Option<D::Weak>,
    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
    where
        D: StrongDeviceIdentifier;
}
335
336impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
337where
338 W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
339 A: IpAddress,
340{
341 fn resolve_addr_with_device(
342 self,
343 device: Option<D::Weak>,
344 ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
345 where
346 D: StrongDeviceIdentifier,
347 {
348 let (addr, zone) = self.into_addr_zone();
349 let device = match (zone, device) {
350 (Some(zone), Some(device)) => {
351 if device != zone {
352 return Err(ZonedAddressError::DeviceZoneMismatch);
353 }
354 Some(EitherDeviceId::Strong(zone))
355 }
356 (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
357 (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
358 (None, None) => {
359 if addr.as_ref().must_have_zone() {
360 return Err(ZonedAddressError::RequiredZoneNotProvided);
361 } else {
362 None
363 }
364 }
365 };
366 Ok((addr, device))
367 }
368}
369
/// The information needed to decide whether a socket may change its bound
/// device. See [`SocketDeviceUpdate::check_update`].
pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
    /// The socket's local IP address, if it has one.
    pub local_ip: Option<&'a SpecifiedAddr<A>>,
    /// The socket's remote IP address, if it has one.
    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
    /// The device the socket is currently bound to, if any.
    pub old_device: Option<&'a D>,
}
383
384impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
385 pub fn check_update<N>(
388 self,
389 new_device: Option<&N>,
390 ) -> Result<(), SocketDeviceUpdateNotAllowedError>
391 where
392 D: PartialEq<N>,
393 {
394 let Self { local_ip, remote_ip, old_device } = self;
395 let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
396 || remote_ip.is_some_and(|a| a.must_have_zone());
397
398 if !must_have_zone {
399 return Ok(());
400 }
401
402 let old_device = old_device.unwrap_or_else(|| {
403 panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
404 });
405
406 if new_device.is_some_and(|new_device| old_device == new_device) {
407 Ok(())
408 } else {
409 Err(SocketDeviceUpdateNotAllowedError)
410 }
411 }
412}
413
/// Error returned by [`SocketDeviceUpdate::check_update`] when the requested
/// device change is not allowed.
pub struct SocketDeviceUpdateNotAllowedError;

/// The identifier types for the two ends of a socket address.
pub trait SocketMapAddrSpec {
    /// The identifier for the local end, e.g. a port (the `dst_port` in
    /// lookups). It must be convertible to a [`NonZeroU16`].
    type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
    /// The identifier for the remote end (the `src_port` in lookups).
    type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}

/// A summary of a listener address, passed to
/// [`SocketMapStateSpec::listener_tag`].
pub struct ListenerAddrInfo {
    /// Whether the listener is bound to a device.
    pub has_device: bool,
    /// Whether the listener is bound to a specific local IP address.
    pub specified_addr: bool,
}
436
437impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
438 pub(crate) fn info(&self) -> ListenerAddrInfo {
439 let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
440 ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
441 }
442}
443
/// Describes what a [`BoundSocketMap`] stores. This covers:
///
/// - socket IDs,
/// - sharing state,
/// - per-address state,
/// - the tags computed from them.
pub trait SocketMapStateSpec {
    /// The tag attached to each address in the map. Tags are computed by
    /// [`listener_tag`](Self::listener_tag) and
    /// [`connected_tag`](Self::connected_tag).
    type AddrVecTag: Eq + Copy + Debug + 'static;

    /// Computes the tag for a listener address from its summary `info` and
    /// its `state`.
    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;

    /// Computes the tag for a connected address, given whether it is bound to
    /// a device and its `state`.
    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;

    /// The ID of a listening socket.
    type ListenerId: Clone + Debug;
    /// The ID of a connected socket.
    type ConnId: Clone + Debug;

    /// The sharing state of a listening socket. It is consulted when deciding
    /// whether the socket may share an address.
    type ListenerSharingState: Clone + Debug;

    /// The sharing state of a connected socket.
    type ConnSharingState: Clone + Debug;

    /// The state stored at a listener address.
    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
        + Debug;

    /// The state stored at a connected address.
    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
        + Debug;
}

/// Error indicating that a sharing state is incompatible with the existing
/// state at an address.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
484
/// A one-shot sink that accepts a single item.
///
/// Implementations:
///
/// - mutable references to any [`Extend`] collection;
/// - `Never`, which cannot be constructed, so its `insert` can never run.
pub trait Inserter<T> {
    /// Consumes the inserter and adds `item` to the underlying collection.
    fn insert(self, item: T);
}

impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        Extend::extend(self, [item])
    }
}

impl<T> Inserter<T> for Never {
    fn insert(self, _: T) {
        // `Never` is uninhabited: there is no value to dispatch on.
        match self {}
    }
}
505
/// The state stored for a single address in the socket map.
pub trait SocketMapAddrStateSpec {
    /// The ID type of the sockets held in this state.
    type Id;

    /// The sharing state that decides whether more sockets may be added.
    type SharingState;

    /// A handle for adding one ID to the state. It is obtained from
    /// [`try_get_inserter`](Self::try_get_inserter).
    type Inserter<'a>: Inserter<Self::Id> + 'a
    where
        Self: 'a,
        Self::Id: 'a;

    /// Creates the state for an address's first socket, `id`, with sharing
    /// state `new_sharing_state`.
    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;

    /// Returns `true` if `id` is held in this state.
    fn contains_id(&self, id: &Self::Id) -> bool;

    /// Returns an [`Inserter`] for adding a socket with `new_sharing_state`.
    ///
    /// Returns [`IncompatibleError`] if such a socket cannot join this state.
    fn try_get_inserter<'a, 'b>(
        &'b mut self,
        new_sharing_state: &'a Self::SharingState,
    ) -> Result<Self::Inserter<'b>, IncompatibleError>;

    /// Checks, without modifying the state, whether a socket with
    /// `new_sharing_state` could be added.
    fn could_insert(&self, new_sharing_state: &Self::SharingState)
        -> Result<(), IncompatibleError>;

    /// Removes `id` from the state.
    ///
    /// The returned [`RemoveResult`] tells the caller whether `id` was the
    /// last socket.
    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
}

/// Address state that supports changing a socket's sharing state in place.
pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
    /// Changes the sharing state of socket `id` to `new_sharing_state`.
    ///
    /// Returns [`IncompatibleError`] if the new sharing state is not allowed.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError>;
}

/// Policy that decides whether a socket may be inserted at an address, given
/// the other entries in the map.
pub trait SocketMapConflictPolicy<
    Addr,
    SharingState,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
>: SocketMapStateSpec
{
    /// Checks whether inserting a socket with `new_sharing_state` at `addr`
    /// would conflict with existing entries in `socketmap`.
    fn check_insert_conflicts(
        new_sharing_state: &SharingState,
        addr: &Addr,
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
    ) -> Result<(), InsertError>;
}

/// Policy that decides whether a socket's sharing state may be changed.
pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
    A: SocketMapAddrSpec,
{
    /// Checks whether the socket at `addr` may change from `old_sharing` to
    /// `new_sharing`, given the other entries in `socketmap`.
    fn allows_sharing_update(
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
        addr: &Addr,
        old_sharing: &SharingState,
        new_sharing: &SharingState,
    ) -> Result<(), UpdateSharingError>;
}
608
/// The value stored at an address in the socket map: either listener state or
/// connected state.
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
#[allow(missing_docs)]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
    Listen(S::ListenerAddrState),
    Conn(S::ConnAddrState),
}

/// A key in the socket map: either a listener address or a connected address.
///
/// `Bound::Listen` values are expected at `AddrVec::Listen` keys, and
/// `Bound::Conn` values at `AddrVec::Conn` keys. The `Tagged` impl for `Bound`
/// treats any other pairing as unreachable.
#[derive(Derivative)]
#[derivative(
    Debug(bound = "D: Debug"),
    Clone(bound = "D: Clone"),
    Eq(bound = "D: Eq"),
    PartialEq(bound = "D: PartialEq"),
    Hash(bound = "D: Hash")
)]
#[allow(missing_docs)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
645
646impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
647 Tagged<AddrVec<I, D, A>> for Bound<S>
648{
649 type Tag = S::AddrVecTag;
650 fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
651 match (self, address) {
652 (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
653 (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
654 S::connected_tag(device.is_some(), c)
655 }
656 (Bound::Listen(_), AddrVec::Conn(_)) => {
657 unreachable!("found listen state for conn addr")
658 }
659 (Bound::Conn(_), AddrVec::Listen(_)) => {
660 unreachable!("found conn state for listen addr")
661 }
662 }
663 }
664}
665
666impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
667 type IterShadows = AddrVecIter<I, D, A>;
668
669 fn iter_shadows(&self) -> Self::IterShadows {
670 let (socket_ip_addr, device) = match self.clone() {
671 AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
672 AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
673 };
674 let mut iter = match device {
675 Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
676 None => AddrVecIter::without_device(socket_ip_addr),
677 };
678 assert_eq!(iter.next().as_ref(), Some(self));
680 iter
681 }
682}
683
/// The kind of a socket address. It can be derived from a `ListenerIpAddr` or
/// a `ConnIpAddr` through the `From` conversions.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[allow(missing_docs)]
pub enum SocketAddrType {
    /// A listener without a specific local IP address.
    AnyListener,
    /// A listener with a specific local IP address.
    SpecificListener,
    /// A connected socket.
    Connected,
}
692
693impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
694 fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
695 match addr {
696 Some(_) => SocketAddrType::SpecificListener,
697 None => SocketAddrType::AnyListener,
698 }
699 }
700}
701
702impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
703 fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
704 SocketAddrType::Connected
705 }
706}
707
/// The outcome of [`SocketMapAddrStateSpec::remove_by_id`].
pub enum RemoveResult {
    /// The ID was removed and the state is still in use.
    Success,
    /// The removed ID was the last one. The caller should remove the address
    /// state from the map, as `SocketStateEntry::remove` does.
    IsLast,
}

/// The ID of a socket in the map: either a listener or a connection.
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub enum SocketId<S: SocketMapStateSpec> {
    /// A listening socket's ID.
    Listener(S::ListenerId),
    /// A connected socket's ID.
    Connection(S::ConnId),
}

/// A map from bound socket addresses to their state.
///
/// Typed views are obtained with `listeners`/`conns` and their `_mut`
/// variants.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The state stored for each address, keyed by [`AddrVec`].
    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
742
743impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
744 BoundSocketMap<I, D, A, S>
745{
746 pub fn len(&self) -> usize {
748 self.addr_to_state.len()
749 }
750}
751
/// Marker type that selects listening sockets in [`Sockets`] and
/// [`ConvertSocketMapState`].
pub enum Listener {}
/// Marker type that selects connected sockets in [`Sockets`] and
/// [`ConvertSocketMapState`].
pub enum Connection {}

/// A view of the socket map restricted to one socket type (`Listener` or
/// `Connection`).
///
/// `AddrToStateMap` is a shared or a mutable reference to the map. Which
/// operations are available depends on which kind of reference it is.
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
759
760impl<
761 'a,
762 I: Ip,
763 D: DeviceIdentifier,
764 SocketType: ConvertSocketMapState<I, D, A, S>,
765 A: SocketMapAddrSpec,
766 S: SocketMapStateSpec,
767> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
768where
769 S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
770{
771 pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
773 let Self(addr_to_state, _marker) = self;
774 addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
775 SocketType::from_bound_ref(state)
776 .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
777 })
778 }
779
780 pub fn could_insert(
786 self,
787 addr: &SocketType::Addr,
788 sharing: &SocketType::SharingState,
789 ) -> Result<(), InsertError> {
790 let Self(addr_to_state, _) = self;
791 match self.get_by_addr(addr) {
792 Some(state) => {
793 state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
794 }
795 None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
796 }
797 }
798}
799
/// A handle to one socket's entry in the socket map. It is returned by
/// `Sockets::try_insert`, `Sockets::try_insert_with`, and `Sockets::entry`.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct SocketStateEntry<
    'a,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
    S: SocketMapStateSpec,
    SocketType,
> {
    /// The ID of the socket this entry refers to.
    id: SocketId<S>,
    /// The occupied map entry for the socket's address.
    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
    _marker: PhantomData<SocketType>,
}
815
816impl<
817 'a,
818 I: Ip,
819 D: DeviceIdentifier,
820 SocketType: ConvertSocketMapState<I, D, A, S>,
821 A: SocketMapAddrSpec,
822 S: SocketMapStateSpec
823 + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
824> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
825where
826 SocketType::SharingState: Clone,
827 SocketType::Id: Clone,
828{
829 pub fn try_insert(
832 self,
833 socket_addr: SocketType::Addr,
834 tag_state: SocketType::SharingState,
835 id: SocketType::Id,
836 ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, (InsertError, SocketType::SharingState)>
837 {
838 self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
839 .map(|(entry, ())| entry)
840 }
841
842 pub fn try_insert_with<R>(
847 self,
848 socket_addr: SocketType::Addr,
849 tag_state: SocketType::SharingState,
850 make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
851 ) -> Result<
852 (SocketStateEntry<'a, I, D, A, S, SocketType>, R),
853 (InsertError, SocketType::SharingState),
854 > {
855 let Self(addr_to_state, _) = self;
856 match S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state) {
857 Err(e) => return Err((e, tag_state)),
858 Ok(()) => (),
859 };
860
861 let addr = SocketType::to_addr_vec(&socket_addr);
862
863 match addr_to_state.entry(addr) {
864 Entry::Occupied(mut o) => {
865 let (id, ret) = o.map_mut(|bound| {
866 let bound = match SocketType::from_bound_mut(bound) {
867 Some(bound) => bound,
868 None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
869 };
870 match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
871 bound, &tag_state,
872 ) {
873 Ok(v) => {
874 let (id, ret) = make_id(socket_addr, tag_state);
875 v.insert(id.clone());
876 Ok((SocketType::to_socket_id(id), ret))
877 }
878 Err(IncompatibleError) => Err((InsertError::Exists, tag_state)),
879 }
880 })?;
881 Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
882 }
883 Entry::Vacant(v) => {
884 let (id, ret) = make_id(socket_addr, tag_state.clone());
885 let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
886 &tag_state,
887 id.clone(),
888 )));
889 let id = SocketType::to_socket_id(id);
890 Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
891 }
892 }
893 }
894
895 pub fn entry(
897 self,
898 id: &SocketType::Id,
899 addr: &SocketType::Addr,
900 ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
901 let Self(addr_to_state, _) = self;
902 let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
903 Entry::Vacant(_) => return None,
904 Entry::Occupied(o) => o,
905 };
906 let state = SocketType::from_bound_ref(addr_entry.get())?;
907
908 state.contains_id(id).then_some(SocketStateEntry {
909 id: SocketType::to_socket_id(id.clone()),
910 addr_entry,
911 _marker: PhantomData::default(),
912 })
913 }
914
915 pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
917 self.entry(id, addr)
918 .map(|entry| {
919 entry.remove();
920 })
921 .ok_or(NotFoundError)
922 }
923}
924
/// Error returned when a socket's sharing state cannot be changed, for example
/// by `SocketStateEntry::try_update_sharing`.
#[derive(Debug)]
pub struct UpdateSharingError;
929
930impl<
931 'a,
932 I: Ip,
933 D: DeviceIdentifier,
934 SocketType: ConvertSocketMapState<I, D, A, S>,
935 A: SocketMapAddrSpec,
936 S: SocketMapStateSpec,
937> SocketStateEntry<'a, I, D, A, S, SocketType>
938where
939 SocketType::Id: Clone,
940{
941 pub fn get_addr(&self) -> &SocketType::Addr {
943 let Self { id: _, addr_entry, _marker } = self;
944 SocketType::from_addr_vec_ref(addr_entry.key())
945 }
946
947 pub fn id(&self) -> &SocketType::Id {
949 let Self { id, addr_entry: _, _marker } = self;
950 SocketType::from_socket_id_ref(id)
951 }
952
953 pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
955 let Self { id, addr_entry, _marker } = self;
956
957 let new_addrvec = SocketType::to_addr_vec(&new_addr);
958 let old_addr = addr_entry.key().clone();
959 let (addr_state, addr_to_state) = addr_entry.remove_from_map();
960 let addr_to_state = match addr_to_state.entry(new_addrvec) {
961 Entry::Occupied(o) => o.into_map(),
962 Entry::Vacant(v) => {
963 if v.descendant_counts().len() != 0 {
964 v.into_map()
965 } else {
966 let new_addr_entry = v.insert(addr_state);
967 return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
968 }
969 }
970 };
971 let to_restore = addr_state;
972 let addr_entry = match addr_to_state.entry(old_addr) {
974 Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
975 Entry::Vacant(v) => v.insert(to_restore),
976 };
977 return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
978 }
979
980 pub fn remove(self) {
982 let Self { id, mut addr_entry, _marker } = self;
983 let addr = addr_entry.key().clone();
984 match addr_entry.map_mut(|value| {
985 let value = match SocketType::from_bound_mut(value) {
986 Some(value) => value,
987 None => unreachable!("found {:?} for address {:?}", value, addr),
988 };
989 value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
990 }) {
991 RemoveResult::Success => (),
992 RemoveResult::IsLast => {
993 let _: Bound<S> = addr_entry.remove();
994 }
995 }
996 }
997
998 pub fn try_update_sharing(
1000 &mut self,
1001 old_sharing_state: &SocketType::SharingState,
1002 new_sharing_state: SocketType::SharingState,
1003 ) -> Result<(), UpdateSharingError>
1004 where
1005 SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1006 S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1007 {
1008 let Self { id, addr_entry, _marker } = self;
1009 let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1010
1011 S::allows_sharing_update(
1012 addr_entry.get_map(),
1013 addr,
1014 old_sharing_state,
1015 &new_sharing_state,
1016 )?;
1017
1018 addr_entry
1019 .map_mut(|value| {
1020 let value = match SocketType::from_bound_mut(value) {
1021 Some(value) => value,
1022 None => unreachable!("found invalid state {:?}", value),
1026 };
1027
1028 value.try_update_sharing(
1029 SocketType::from_socket_id_ref(id).clone(),
1030 &new_sharing_state,
1031 )
1032 })
1033 .map_err(|IncompatibleError| UpdateSharingError)
1034 }
1035}
1036
1037impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1038where
1039 AddrVec<I, D, A>: IterShadows,
1040 S: SocketMapStateSpec,
1041{
1042 pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1044 where
1045 S: SocketMapConflictPolicy<
1046 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1047 <S as SocketMapStateSpec>::ListenerSharingState,
1048 I,
1049 D,
1050 A,
1051 >,
1052 S::ListenerAddrState:
1053 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1054 {
1055 let Self { addr_to_state } = self;
1056 Sockets(addr_to_state, Default::default())
1057 }
1058
1059 pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1061 where
1062 S: SocketMapConflictPolicy<
1063 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1064 <S as SocketMapStateSpec>::ListenerSharingState,
1065 I,
1066 D,
1067 A,
1068 >,
1069 S::ListenerAddrState:
1070 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1071 {
1072 let Self { addr_to_state } = self;
1073 Sockets(addr_to_state, Default::default())
1074 }
1075
1076 pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1078 where
1079 S: SocketMapConflictPolicy<
1080 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1081 <S as SocketMapStateSpec>::ConnSharingState,
1082 I,
1083 D,
1084 A,
1085 >,
1086 S::ConnAddrState:
1087 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1088 {
1089 let Self { addr_to_state } = self;
1090 Sockets(addr_to_state, Default::default())
1091 }
1092
1093 pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1095 where
1096 S: SocketMapConflictPolicy<
1097 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1098 <S as SocketMapStateSpec>::ConnSharingState,
1099 I,
1100 D,
1101 A,
1102 >,
1103 S::ConnAddrState:
1104 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1105 {
1106 let Self { addr_to_state } = self;
1107 Sockets(addr_to_state, Default::default())
1108 }
1109
1110 #[cfg(test)]
1111 pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1112 let Self { addr_to_state } = self;
1113 addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1114 }
1115
1116 pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1118 let Self { addr_to_state } = self;
1119 addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1120 }
1121}
1122
/// The result of `BoundSocketMap::iter_receivers`.
pub enum FoundSockets<A, It> {
    /// The single matching socket, returned for non-broadcast,
    /// non-multicast destinations.
    Single(A),
    /// All matching sockets, returned for broadcast or multicast
    /// destinations.
    Multicast(It),
}

/// An address in the socket map together with the state stored there.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(
        &'a S::ConnAddrState,
        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
    ),
}
1142
1143impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1144where
1145 I: BroadcastIpExt<Addr: MulticastAddress>,
1146 D: DeviceIdentifier,
1147 A: SocketMapAddrSpec,
1148 S: SocketMapStateSpec
1149 + SocketMapConflictPolicy<
1150 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1151 <S as SocketMapStateSpec>::ListenerSharingState,
1152 I,
1153 D,
1154 A,
1155 > + SocketMapConflictPolicy<
1156 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1157 <S as SocketMapStateSpec>::ConnSharingState,
1158 I,
1159 D,
1160 A,
1161 >,
1162{
1163 pub fn lookup_connected(
1169 &self,
1170 (src_ip, src_port): (SocketIpAddr<I::Addr>, A::RemoteIdentifier),
1171 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1172 device: D,
1173 ) -> Option<&'_ S::ConnAddrState> {
1174 let mut addr = ConnAddr {
1175 ip: ConnIpAddr { local: (dst_ip, dst_port), remote: (src_ip, src_port) },
1176 device: Some(device),
1177 };
1178 let entry = self.conns().get_by_addr(&addr);
1179 if entry.is_some() {
1180 return entry;
1181 }
1182 addr.device = None;
1183 self.conns().get_by_addr(&addr)
1184 }
1185
1186 pub fn iter_receivers(
1192 &self,
1193 (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1194 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1195 device: D,
1196 broadcast: Option<I::BroadcastMarker>,
1197 ) -> Option<
1198 FoundSockets<
1199 AddrEntry<'_, I, D, A, S>,
1200 impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1201 >,
1202 > {
1203 let mut matching_entries = AddrVecIter::with_device(
1204 match (src_ip, src_port) {
1205 (Some(specified_src_ip), Some(src_port)) => {
1206 ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1207 .into()
1208 }
1209 _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1210 },
1211 device,
1212 )
1213 .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1214 AddrVec::Listen(l) => {
1215 self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1216 }
1217 AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1218 });
1219
1220 if broadcast.is_some() || dst_ip.addr().is_multicast() {
1221 Some(FoundSockets::Multicast(matching_entries))
1222 } else {
1223 let single_entry: Option<_> = matching_entries.next();
1224 single_entry.map(FoundSockets::Single)
1225 }
1226 }
1227}
1228
/// Error returned when a socket cannot be inserted into the socket map.
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// Presumably: a conflicting socket exists at an address that the new one
    /// shadows. The exact meaning is set by `SocketMapConflictPolicy`
    /// implementations; confirm there.
    ShadowAddrExists,
    /// The address already has state that the new socket is incompatible
    /// with.
    Exists,
    /// Presumably: a conflicting socket exists at an address that shadows the
    /// new one. The exact meaning is set by conflict policies; confirm there.
    ShadowerExists,
    /// The insertion conflicts with some other entry in the map. The exact
    /// meaning is set by conflict policies; confirm there.
    IndirectConflict,
}

/// Conversions between the socket map's generic key/value types and those of
/// one socket type. Implemented by [`Listener`] and [`Connection`].
pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The ID of a socket of this type.
    type Id;
    /// The sharing state of a socket of this type.
    type SharingState;
    /// The address of a socket of this type.
    type Addr: Debug;
    /// The per-address state for sockets of this type.
    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;

    /// Wraps `addr` as a socket map key.
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
    /// Borrows this type's address out of a socket map key.
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
    /// Borrows this type's state out of `bound`. Returns `None` if `bound`
    /// holds the other kind of state.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
    /// Mutably borrows this type's state out of `bound`. Returns `None` if
    /// `bound` holds the other kind of state.
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
    /// Wraps `state` as a socket map value.
    fn to_bound(state: Self::AddrState) -> Bound<S>;
    /// Wraps `id` as a [`SocketId`].
    fn to_socket_id(id: Self::Id) -> SocketId<S>;
    /// Borrows this type's ID out of `id`.
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
1258
1259impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1260 ConvertSocketMapState<I, D, A, S> for Listener
1261{
1262 type Id = S::ListenerId;
1263 type SharingState = S::ListenerSharingState;
1264 type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1265 type AddrState = S::ListenerAddrState;
1266 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1267 AddrVec::Listen(addr.clone())
1268 }
1269
1270 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1271 match addr {
1272 AddrVec::Listen(l) => l,
1273 AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1274 }
1275 }
1276
1277 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1278 match bound {
1279 Bound::Listen(l) => Some(l),
1280 Bound::Conn(_c) => None,
1281 }
1282 }
1283
1284 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1285 match bound {
1286 Bound::Listen(l) => Some(l),
1287 Bound::Conn(_c) => None,
1288 }
1289 }
1290
1291 fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1292 Bound::Listen(state)
1293 }
1294 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1295 match id {
1296 SocketId::Listener(id) => id,
1297 SocketId::Connection(_) => unreachable!("connection ID for listener"),
1298 }
1299 }
1300 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1301 SocketId::Listener(id)
1302 }
1303}
1304
1305impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1306 ConvertSocketMapState<I, D, A, S> for Connection
1307{
1308 type Id = S::ConnId;
1309 type SharingState = S::ConnSharingState;
1310 type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1311 type AddrState = S::ConnAddrState;
1312 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1313 AddrVec::Conn(addr.clone())
1314 }
1315
1316 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1317 match addr {
1318 AddrVec::Conn(c) => c,
1319 AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1320 }
1321 }
1322
1323 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1324 match bound {
1325 Bound::Listen(_l) => None,
1326 Bound::Conn(c) => Some(c),
1327 }
1328 }
1329
1330 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1331 match bound {
1332 Bound::Listen(_l) => None,
1333 Bound::Conn(c) => Some(c),
1334 }
1335 }
1336
1337 fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1338 Bound::Conn(state)
1339 }
1340
1341 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1342 match id {
1343 SocketId::Connection(id) => id,
1344 SocketId::Listener(_) => unreachable!("listener ID for connection"),
1345 }
1346 }
1347 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1348 SocketId::Connection(id)
1349 }
1350}
1351
/// Opaque identifier naming a group of sockets for port-reuse sharing.
///
/// Two enabled `ReusePortOption`s are considered shareable only when their
/// `SharingDomain`s compare equal.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct SharingDomain(u64);

impl SharingDomain {
    /// Builds a sharing domain from the raw identifier `id`.
    ///
    /// This is a `const fn`, so domains can be used in constant contexts.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }
}
1364
1365#[derive(Default, Debug, Eq, PartialEq, Clone, Copy, Hash)]
1368pub enum ReusePortOption {
1369 #[default]
1371 Disabled,
1372
1373 Enabled(SharingDomain),
1376}
1377
1378impl ReusePortOption {
1379 pub fn is_enabled(&self) -> bool {
1381 matches!(self, ReusePortOption::Enabled(_))
1382 }
1383
1384 pub fn is_shareable_with(&self, other: &Self) -> bool {
1387 match (self, other) {
1388 (ReusePortOption::Enabled(domain1), ReusePortOption::Enabled(domain2)) => {
1389 domain1 == domain2
1390 }
1391 _ => false,
1392 }
1393 }
1394}
1395
/// Unit tests for the zone helpers and for `BoundSocketMap`: insertion,
/// removal, conflict detection, sharing updates and connected-socket lookup.
///
/// The map under test is instantiated over IPv4 and `FakeWeakDeviceId` with
/// the fake specs defined in this module (`FakeSpec`, `FakeAddrSpec`).
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use assert_matches::assert_matches;
    use net_declare::{net_ip_v4, net_ip_v6};
    use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
    use netstack3_hashmap::HashSet;
    use test_case::test_case;

    use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
    use crate::testutil::set_logger_for_test;

    use super::*;

    #[test_case(net_ip_v4!("8.8.8.8"))]
    #[test_case(net_ip_v4!("127.0.0.1"))]
    #[test_case(net_ip_v4!("127.0.8.9"))]
    #[test_case(net_ip_v4!("224.1.2.3"))]
    fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
        let addr = SpecifiedAddr::new(addr).unwrap();
        // Every IPv4 case above (global unicast, loopback, multicast) is
        // expected to report that no zone is required.
        assert_eq!(addr.must_have_zone(), false);
    }

    // Of the cases below, only the link-local unicast address (`fe80::1`) and
    // the ff02:: multicast addresses (including all-nodes link-local) are
    // expected to require a zone.
    #[test_case(net_ip_v6!("1::2:3"), false)]
    #[test_case(net_ip_v6!("::1"), false; "localhost")]
    #[test_case(net_ip_v6!("1::"), false)]
    #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
    #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
    #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
    #[test_case(net_ip_v6!("fe80::1"), true)]
    fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
        let addr = SpecifiedAddr::new(addr).unwrap();
        assert_eq!(addr.must_have_zone(), must_have);
    }

    /// Checks `try_into_null_zoned` on IPv6 addresses:
    /// - loopback returns `None`;
    /// - the all-nodes link-local multicast address returns an `AddrAndZone`
    ///   with a `()` zone, which equals `AddrAndZone::new(zoned, ZONE)` once
    ///   its zone is mapped to `ZONE`.
    #[test]
    fn try_into_null_zoned_ipv6() {
        assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
        let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
        const ZONE: u32 = 5;
        assert_eq!(
            zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
            Some(AddrAndZone::new(zoned, ZONE).unwrap())
        );
    }

    /// Uninhabited marker type that carries the fake `SocketMapStateSpec`.
    enum FakeSpec {}

    /// Fake listener socket ID.
    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
    struct Listener(usize);

    /// Fake sharing state.
    ///
    /// Sockets may join an existing address only with an *identical* state.
    /// Shadow relationships are allowed only when both sides carry the same
    /// `tag` and are `shared` (see `can_share_with`).
    #[derive(PartialEq, Eq, Debug, Copy, Clone)]
    struct SharingState {
        tag: char,
        shared: bool,
    }

    impl SharingState {
        /// A non-shared state with the given tag.
        fn exclusive(tag: char) -> Self {
            Self { tag, shared: false }
        }

        /// A shared state with the given tag.
        fn shared(tag: char) -> Self {
            Self { tag, shared: true }
        }
    }

    impl SharingState {
        /// True when both states have the same tag and both are `shared`.
        fn can_share_with(&self, other: &Self) -> bool {
            self.tag == other.tag && self.shared && other.shared
        }
    }

    /// Fake per-address state: every socket ID bound at one address, plus the
    /// single sharing state they all have in common.
    #[derive(PartialEq, Eq, Debug)]
    struct Multiple<T> {
        sharing_state: SharingState,
        entries: Vec<T>,
    }

    impl<T> Multiple<T> {
        /// Builds an address state with a non-shared `tag` and the given
        /// entries.
        fn new_exclusive(tag: char, entries: Vec<T>) -> Self {
            Self { sharing_state: SharingState { tag, shared: false }, entries }
        }
    }

    /// Fake connected socket ID.
    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
    struct Conn(usize);

    /// Fake address spec: `NonZeroU16` local identifiers and `()` remote
    /// identifiers.
    enum FakeAddrSpec {}

    impl SocketMapAddrSpec for FakeAddrSpec {
        type LocalIdentifier = NonZeroU16;
        type RemoteIdentifier = ();
    }

    impl SocketMapStateSpec for FakeSpec {
        type AddrVecTag = SharingState;

        type ListenerId = Listener;
        type ConnId = Conn;

        type ListenerSharingState = SharingState;
        type ConnSharingState = SharingState;

        type ListenerAddrState = Multiple<Listener>;
        type ConnAddrState = Multiple<Conn>;

        // The tag for an address is just its state's sharing state; the
        // listener info / device flag are ignored by this fake.
        fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
            state.sharing_state
        }

        fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
            state.sharing_state
        }
    }

    /// IPv4 socket map built from the fakes above.
    type FakeBoundSocketMap =
        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;

    /// Hands out sequential `usize` IDs, starting at 0, for fake sockets.
    ///
    /// IDs from one generator never repeat, so each fake socket in a test
    /// gets a distinct ID.
    #[derive(Default)]
    struct FakeSocketIdGen {
        next_id: usize,
    }

    impl FakeSocketIdGen {
        /// Returns the current ID and advances the counter by one.
        fn next(&mut self) -> usize {
            let next_next_id = self.next_id + 1;
            core::mem::replace(&mut self.next_id, next_next_id)
        }
    }

    impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
        type Id = I;
        type SharingState = SharingState;
        type Inserter<'a>
            = &'a mut Vec<I>
        where
            I: 'a;

        /// Creates the state for a fresh address holding only `id`.
        fn new(sharing_state: &SharingState, id: I) -> Self {
            Self { sharing_state: *sharing_state, entries: vec![id] }
        }

        fn contains_id(&self, id: &Self::Id) -> bool {
            self.entries.contains(id)
        }

        /// Gives out the entry list for insertion only when
        /// `new_sharing_state` exactly matches this address's state.
        fn try_get_inserter<'a, 'b>(
            &'b mut self,
            new_sharing_state: &'a SharingState,
        ) -> Result<Self::Inserter<'b>, IncompatibleError> {
            (self.sharing_state == *new_sharing_state)
                .then_some(&mut self.entries)
                .ok_or(IncompatibleError)
        }

        /// Same compatibility rule as `try_get_inserter`, without mutating.
        fn could_insert(&self, new_sharing_state: &SharingState) -> Result<(), IncompatibleError> {
            (self.sharing_state == *new_sharing_state).then_some(()).ok_or(IncompatibleError)
        }

        /// Removes `id`, reporting `IsLast` when the address is left empty.
        ///
        /// Panics if `id` is not present. `swap_remove` does not preserve
        /// entry order.
        fn remove_by_id(&mut self, id: I) -> RemoveResult {
            let index = self.entries.iter().position(|i| i == &id).expect("did not find id");
            let _: I = self.entries.swap_remove(index);
            if self.entries.is_empty() { RemoveResult::IsLast } else { RemoveResult::Success }
        }
    }

    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
        SocketMapConflictPolicy<A, SharingState, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
        for FakeSpec
    {
        /// Fake conflict policy, checked in this order:
        /// 1. any occupied address that `addr` shadows must be shareable
        ///    (`ShadowAddrExists`);
        /// 2. an existing entry at `addr` itself must have an identical
        ///    sharing state (`Exists`);
        /// 3. every tag reported by `descendant_counts(addr)` must be
        ///    shareable (`ShadowerExists`).
        fn check_insert_conflicts(
            new_sharing_state: &SharingState,
            addr: &A,
            socketmap: &SocketMap<
                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
                Bound<FakeSpec>,
            >,
        ) -> Result<(), InsertError> {
            let dest: AddrVec<_, _, _> = addr.clone().into();
            if dest.iter_shadows().any(|a| {
                let entry = socketmap.get(&a);
                match entry {
                    Some(Bound::Listen(Multiple { sharing_state, .. }))
                    | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
                        !sharing_state.can_share_with(new_sharing_state)
                    }
                    None => false,
                }
            }) {
                return Err(InsertError::ShadowAddrExists);
            }

            match socketmap.get(&dest) {
                Some(Bound::Listen(Multiple { sharing_state, .. }))
                | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
                    // Joining an occupied address requires an exact match of
                    // sharing state, not just `can_share_with`.
                    if sharing_state != new_sharing_state {
                        return Err(InsertError::Exists);
                    }
                }
                None => (),
            }

            if socketmap
                .descendant_counts(&dest)
                .any(|(sharing_state, _count)| !sharing_state.can_share_with(new_sharing_state))
            {
                Err(InsertError::ShadowerExists)
            } else {
                Ok(())
            }
        }
    }

    impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
        /// Changes the address's sharing state for `id`.
        ///
        /// Changing to the current state is always `Ok`.
        fn try_update_sharing(
            &mut self,
            id: Self::Id,
            new_sharing_state: &Self::SharingState,
        ) -> Result<(), IncompatibleError> {
            if self.sharing_state == *new_sharing_state {
                return Ok(());
            }

            // There is one sharing state per address, so changing it would
            // affect every socket there. Only allow the change when `id` is
            // the sole entry.
            if self.entries.len() != 1 {
                return Err(IncompatibleError);
            }
            assert!(self.entries.contains(&id));
            self.sharing_state = *new_sharing_state;
            Ok(())
        }
    }

    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
        SocketMapUpdateSharingPolicy<
            A,
            SharingState,
            Ipv4,
            FakeWeakDeviceId<FakeDeviceId>,
            FakeAddrSpec,
        > for FakeSpec
    {
        // The fake imposes no map-level restriction on sharing updates.
        // Rejections in these tests come from `Multiple::try_update_sharing`.
        fn allows_sharing_update(
            _socketmap: &SocketMap<
                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
                Bound<Self>,
            >,
            _addr: &A,
            _old_sharing: &SharingState,
            _new_sharing_state: &SharingState,
        ) -> Result<(), UpdateSharingError> {
            Ok(())
        }
    }

    /// Listener address 1.2.3.4 with identifier 1 and no device.
    const LISTENER_ADDR: ListenerAddr<
        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
        FakeWeakDeviceId<FakeDeviceId>,
    > = ListenerAddr {
        ip: ListenerIpAddr {
            // SAFETY: 1.2.3.4 is an ordinary unicast address, assumed to meet
            // `SocketIpAddr`'s invariants (the checked `SocketIpAddr::new` is
            // `unwrap`ped on similar addresses elsewhere in these tests).
            // NOTE(review): confirm against `new_unchecked`'s safety contract.
            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
            identifier: NonZeroU16::new(1).unwrap(),
        },
        device: None,
    };

    /// Connection 5.6.7.8 (identifier 1) -> 8.7.6.5, with no device.
    const CONN_ADDR: ConnAddr<
        ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
        FakeWeakDeviceId<FakeDeviceId>,
    > = ConnAddr {
        ip: ConnIpAddr {
            // SAFETY (both `unsafe` blocks below): 5.6.7.8 and 8.7.6.5 are
            // ordinary unicast addresses, assumed to meet `SocketIpAddr`'s
            // invariants. NOTE(review): confirm against `new_unchecked`'s
            // safety contract.
            local: (
                unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
                NonZeroU16::new(1).unwrap(),
            ),
            remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
        },
        device: None,
    };

    /// Insert, look up, then remove a single listener.
    #[test]
    fn bound_insert_get_remove_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let addr = LISTENER_ADDR;

        let id = {
            let entry = bound
                .listeners_mut()
                .try_insert(addr, SharingState::exclusive('v'), Listener(fake_id_gen.next()))
                .unwrap();
            assert_eq!(entry.get_addr(), &addr);
            entry.id().clone()
        };

        assert_eq!(
            bound.listeners().get_by_addr(&addr),
            Some(&Multiple::new_exclusive('v', vec![id]))
        );

        assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
        assert_eq!(bound.listeners().get_by_addr(&addr), None);
    }

    /// Insert, look up, then remove a single connection.
    #[test]
    fn bound_insert_get_remove_conn() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let addr = CONN_ADDR;

        let id = {
            let entry = bound
                .conns_mut()
                .try_insert(addr, SharingState::exclusive('v'), Conn(fake_id_gen.next()))
                .unwrap();
            assert_eq!(entry.get_addr(), &addr);
            entry.id().clone()
        };

        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('v', vec![id])));

        assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
        assert_eq!(bound.conns().get_by_addr(&addr), None);
    }

    /// `iter_addrs` yields exactly the set of listener and connection
    /// addresses that were inserted.
    #[test]
    fn bound_iter_addrs() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let listener_addrs = [
            (Some(net_ip_v4!("1.1.1.1")), 1),
            (Some(net_ip_v4!("2.2.2.2")), 2),
            (Some(net_ip_v4!("1.1.1.1")), 3),
            (None, 4),
        ]
        .map(|(ip, identifier)| ListenerAddr {
            device: None,
            ip: ListenerIpAddr {
                addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
                identifier: NonZeroU16::new(identifier).unwrap(),
            },
        });
        let conn_addrs = [
            (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
            (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
        ]
        .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
            ip: ConnIpAddr {
                local: (
                    SocketIpAddr::new(local_ip).unwrap(),
                    NonZeroU16::new(local_identifier).unwrap(),
                ),
                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
            },
            device: None,
        });

        for addr in listener_addrs.iter().cloned() {
            let _entry = bound
                .listeners_mut()
                .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
                .unwrap();
        }
        for addr in conn_addrs.iter().cloned() {
            let _entry = bound
                .conns_mut()
                .try_insert(addr, SharingState::exclusive('a'), Conn(fake_id_gen.next()))
                .unwrap();
        }
        let expected_addrs = listener_addrs
            .into_iter()
            .map(Into::into)
            .chain(conn_addrs.into_iter().map(Into::into))
            .collect::<HashSet<_>>();

        assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
    }

    /// `try_insert_with` must report the conflict without invoking its
    /// callback when the insertion is rejected.
    #[test]
    fn try_insert_with_callback_not_called_on_error() {
        set_logger_for_test();
        // Seed the map with an exclusive listener at `addr`. Each insertion
        // attempted below conflicts with it in a different way.
        let mut bound = FakeBoundSocketMap::default();
        let addr = LISTENER_ADDR;

        // The ID is bound only to confirm the seed insert succeeded.
        let _: &Listener = bound
            .listeners_mut()
            .try_insert(addr, SharingState::exclusive('a'), Listener(0))
            .unwrap()
            .id();

        // Callback handed to `try_insert_with`. It panics, so any call to it
        // fails the test.
        fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
            panic!("should never be called");
        }

        // Same address, different sharing state: `Exists`.
        assert_matches!(
            bound.listeners_mut().try_insert_with(
                addr,
                SharingState::exclusive('b'),
                is_never_called
            ),
            Err((InsertError::Exists, SharingState { .. }))
        );
        // A device-bound copy of `addr` shadows the existing device-less
        // listener: `ShadowAddrExists`.
        assert_matches!(
            bound.listeners_mut().try_insert_with(
                ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
                SharingState::exclusive('b'),
                is_never_called
            ),
            Err((InsertError::ShadowAddrExists, _))
        );
        // A connection whose local address/identifier equals the listener's
        // also shadows it: `ShadowAddrExists`.
        assert_matches!(
            bound.conns_mut().try_insert_with(
                ConnAddr {
                    device: None,
                    ip: ConnIpAddr {
                        local: (addr.ip.addr.unwrap(), addr.ip.identifier),
                        remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
                    },
                },
                SharingState::exclusive('b'),
                is_never_called,
            ),
            Err((InsertError::ShadowAddrExists, _))
        );
    }

    /// A second listener at the same address with a different sharing state
    /// is rejected with `Exists`. The rejected state is handed back.
    #[test]
    fn insert_listener_conflict_with_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;

        let _: &Listener = bound
            .listeners_mut()
            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
            .unwrap()
            .id();
        assert_matches!(
            bound.listeners_mut().try_insert(
                addr,
                SharingState::exclusive('b'),
                Listener(fake_id_gen.next())
            ),
            Err((InsertError::Exists, SharingState { tag: 'b', .. }))
        );
    }

    /// A device-bound listener that shadows an existing non-shareable
    /// listener is rejected with `ShadowAddrExists`.
    #[test]
    fn insert_listener_conflict_with_shadower() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;
        let shadows_addr = {
            assert_eq!(addr.device, None);
            ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
        };

        let _: &Listener = bound
            .listeners_mut()
            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
            .unwrap()
            .id();
        assert_matches!(
            bound.listeners_mut().try_insert(
                shadows_addr,
                SharingState::exclusive('b'),
                Listener(fake_id_gen.next())
            ),
            Err((InsertError::ShadowAddrExists, SharingState { tag: 'b', .. }))
        );
    }

    /// A connection on an existing listener's local address/identifier is
    /// rejected with `ShadowAddrExists`.
    #[test]
    fn insert_conn_conflict_with_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;
        let shadows_addr = ConnAddr {
            device: None,
            ip: ConnIpAddr {
                local: (addr.ip.addr.unwrap(), addr.ip.identifier),
                remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
            },
        };

        let _: &Listener = bound
            .listeners_mut()
            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
            .unwrap()
            .id();
        assert_matches!(
            bound.conns_mut().try_insert(
                shadows_addr,
                SharingState::exclusive('b'),
                Conn(fake_id_gen.next())
            ),
            Err((InsertError::Exists, SharingState { tag: 'b', .. }))
                | Err((InsertError::ShadowAddrExists, SharingState { tag: 'b', .. }))
        );
    }

    /// Two listeners with an identical sharing state coexist at one address.
    /// Removing one leaves the other in place.
    #[test]
    fn insert_and_remove_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;

        let a = bound
            .listeners_mut()
            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
            .unwrap()
            .id()
            .clone();
        let b = bound
            .listeners_mut()
            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
            .unwrap()
            .id()
            .clone();
        assert_ne!(a, b);

        assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
        assert_eq!(
            bound.listeners().get_by_addr(&addr),
            Some(&Multiple::new_exclusive('x', vec![b]))
        );
    }

    /// Connection counterpart of `insert_and_remove_listener`.
    #[test]
    fn insert_and_remove_conn() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = CONN_ADDR;

        let a = bound
            .conns_mut()
            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
            .unwrap()
            .id()
            .clone();
        let b = bound
            .conns_mut()
            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
            .unwrap()
            .id()
            .clone();
        assert_ne!(a, b);

        assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('x', vec![b])));
    }

    /// Moving either of two exclusive listeners to the wildcard address with
    /// the same identifier fails. On failure the entry is handed back
    /// unchanged.
    #[test]
    fn update_listener_to_shadowed_addr_fails() {
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let first_addr = LISTENER_ADDR;
        let second_addr = ListenerAddr {
            ip: ListenerIpAddr {
                addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
                ..LISTENER_ADDR.ip
            },
            ..LISTENER_ADDR
        };
        let both_shadow = ListenerAddr {
            ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
            device: None,
        };

        let first = bound
            .listeners_mut()
            .try_insert(first_addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
            .unwrap()
            .id()
            .clone();
        let second = bound
            .listeners_mut()
            .try_insert(second_addr, SharingState::exclusive('b'), Listener(fake_id_gen.next()))
            .unwrap()
            .id()
            .clone();

        // `both_shadow` keeps the identifier but drops the specific address.
        // The other, non-shareable listener still occupies a specific address
        // under that identifier, so the update must be rejected.
        let (ExistsError, entry) = bound
            .listeners_mut()
            .entry(&second, &second_addr)
            .unwrap()
            .try_update_addr(both_shadow)
            .expect_err("update should fail");

        assert_eq!(entry.id(), &second);
        // Release the returned entry, and with it the borrow of `bound`,
        // before taking an entry for `first`.
        drop(entry);

        let (ExistsError, entry) = bound
            .listeners_mut()
            .entry(&first, &first_addr)
            .unwrap()
            .try_update_addr(both_shadow)
            .expect_err("update should fail");
        assert_eq!(entry.get_addr(), &first_addr);
    }

    /// After a connection is removed, `entry` for its ID/address returns
    /// `None`.
    #[test]
    fn nonexistent_conn_entry() {
        let mut map = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = CONN_ADDR;
        let conn_id = map
            .conns_mut()
            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
            .expect("failed to insert")
            .id()
            .clone();
        assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));

        assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
    }

    /// A lone connection can change its sharing state. Once a second
    /// connection shares the address, a further change is refused.
    #[test]
    fn update_conn_sharing() {
        let mut map = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = CONN_ADDR;
        let mut entry = map
            .conns_mut()
            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
            .expect("failed to insert");

        entry
            .try_update_sharing(&SharingState::exclusive('a'), SharingState::exclusive('d'))
            .expect("worked");
        // The address now carries 'd', so a second connection using exactly
        // that state can join it.
        let mut second_conn = map
            .conns_mut()
            .try_insert(addr.clone(), SharingState::exclusive('d'), Conn(fake_id_gen.next()))
            .expect("can insert");
        assert_matches!(
            second_conn
                .try_update_sharing(&SharingState::exclusive('d'), SharingState::exclusive('e')),
            Err(UpdateSharingError)
        );
    }

    /// `lookup_connected` finds the device-less connection. Once a
    /// device-bound connection exists at the same addresses, a lookup for
    /// that device returns the device-bound one.
    #[test]
    fn lookup_connected() {
        let mut map = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let sharing_state = SharingState::shared('a');

        let device_id = FakeWeakDeviceId(FakeDeviceId);
        let entry1 = map
            .conns_mut()
            .try_insert(CONN_ADDR, sharing_state, Conn(fake_id_gen.next()))
            .expect("failed to insert")
            .id()
            .clone();
        let conn = map
            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
            .expect("lookup should succeed");
        assert!(conn.contains_id(&entry1));

        let addr_with_device = ConnAddr { device: Some(device_id), ..CONN_ADDR };
        // Both connections use a shared state with the same tag, so the
        // device-bound insert is allowed alongside the device-less one.
        let entry2 = map
            .conns_mut()
            .try_insert(addr_with_device, sharing_state, Conn(fake_id_gen.next()))
            .expect("failed to insert")
            .id()
            .clone();
        let conn = map
            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
            .expect("lookup should succeed");
        assert!(conn.contains_id(&entry2));
    }
}