netstack3_base/
matchers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Trait definition for matchers.
6
7use alloc::format;
8use alloc::string::String;
9use core::convert::Infallible as Never;
10use core::fmt::Debug;
11use core::num::NonZeroU64;
12use core::ops::RangeInclusive;
13
14use bitflags::bitflags;
15use derivative::Derivative;
16use net_types::ip::{IpAddr, IpAddress, Ipv4Addr, Ipv6Addr, Subnet};
17
18use crate::{InspectableValue, Inspector, Mark, MarkDomain, MarkStorage, Marks};
19
/// Trait defining required types for matchers provided by bindings.
///
/// Allows rules that match on device class to be installed, storing the
/// [`MatcherBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
/// Core to have Bindings provide the type since it is platform-specific.
pub trait MatcherBindingsTypes {
    /// The device class type for devices installed in the netstack.
    ///
    /// Implementations must provide a type that is both `Clone` and `Debug`.
    type DeviceClass: Clone + Debug;
}
29
/// Common pattern to define a matcher for a metadata input `T`.
///
/// Used in matching engines like filtering and routing rules.
pub trait Matcher<T> {
    /// Returns whether the provided value matches.
    fn matches(&self, actual: &T) -> bool;

    /// Returns whether the provided value is set and matches.
    ///
    /// The default implementation returns `false` when `actual` is `None`,
    /// and otherwise defers to [`Matcher::matches`].
    fn required_matches(&self, actual: Option<&T>) -> bool {
        actual.is_some_and(|actual| self.matches(actual))
    }
}
42
43/// Implement `Matcher` for optional matchers, so that if a matcher is left
44/// unspecified, it matches all inputs by default.
45impl<T, O> Matcher<T> for Option<O>
46where
47    O: Matcher<T>,
48{
49    fn matches(&self, actual: &T) -> bool {
50        self.as_ref().map_or(true, |expected| expected.matches(actual))
51    }
52
53    fn required_matches(&self, actual: Option<&T>) -> bool {
54        self.as_ref().map_or(true, |expected| expected.required_matches(actual))
55    }
56}
57
/// Matcher that matches IP addresses in a subnet.
///
/// Wraps the [`Subnet`] that a tested address must be contained in.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SubnetMatcher<A: IpAddress>(pub Subnet<A>);
61
62impl<A: IpAddress> Matcher<A> for SubnetMatcher<A> {
63    fn matches(&self, actual: &A) -> bool {
64        let Self(matcher) = self;
65        matcher.contains(actual)
66    }
67}
68
/// A matcher for network interfaces.
///
/// Each variant selects one interface property to compare against; the
/// comparison itself is delegated to [`InterfaceProperties`].
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub enum InterfaceMatcher<DeviceClass> {
    /// The ID of the interface as assigned by the netstack.
    Id(NonZeroU64),
    /// Match based on name.
    Name(String),
    /// The device class of the interface.
    DeviceClass(DeviceClass),
}
80
81impl<DeviceClass: Debug> InspectableValue for InterfaceMatcher<DeviceClass> {
82    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
83        match self {
84            InterfaceMatcher::Id(id) => inspector.record_string(name, format!("Id({})", id.get())),
85            InterfaceMatcher::Name(iface_name) => {
86                inspector.record_string(name, format!("Name({iface_name})"))
87            }
88            InterfaceMatcher::DeviceClass(class) => {
89                inspector.record_debug(name, format!("Class({class:?})"))
90            }
91        };
92    }
93}
94
/// Allows code to match on properties of an interface (ID, name, and device
/// class) without Netstack3 Core (or Bindings, in the case of the device class)
/// having to specifically expose that state.
///
/// [`InterfaceMatcher`] calls exactly one of these methods, chosen by its
/// variant.
pub trait InterfaceProperties<DeviceClass> {
    /// Returns whether the provided ID matches the interface.
    fn id_matches(&self, id: &NonZeroU64) -> bool;

    /// Returns whether the provided name matches the interface.
    fn name_matches(&self, name: &str) -> bool;

    /// Returns whether the provided device class matches the interface.
    fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
}
108
109impl<DeviceClass, I: InterfaceProperties<DeviceClass>> Matcher<I>
110    for InterfaceMatcher<DeviceClass>
111{
112    fn matches(&self, actual: &I) -> bool {
113        match self {
114            InterfaceMatcher::Id(id) => actual.id_matches(id),
115            InterfaceMatcher::Name(name) => actual.name_matches(name),
116            InterfaceMatcher::DeviceClass(device_class) => {
117                actual.device_class_matches(device_class)
118            }
119        }
120    }
121}
122
/// Matcher for the bound device of locally generated traffic.
///
/// Matched against an optional device: `Bound` requires a device to be
/// present and to satisfy the inner matcher, `Unbound` requires no device.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BoundInterfaceMatcher<DeviceClass> {
    /// The packet is bound to a device which is matched by the matcher.
    Bound(InterfaceMatcher<DeviceClass>),
    /// There is no bound device.
    Unbound,
}
131
132impl<'a, DeviceClass, D: InterfaceProperties<DeviceClass>> Matcher<Option<&'a D>>
133    for BoundInterfaceMatcher<DeviceClass>
134{
135    fn matches(&self, actual: &Option<&'a D>) -> bool {
136        match self {
137            BoundInterfaceMatcher::Bound(matcher) => matcher.required_matches(actual.as_deref()),
138            BoundInterfaceMatcher::Unbound => actual.is_none(),
139        }
140    }
141}
142
143impl<DeviceClass: Debug> InspectableValue for BoundInterfaceMatcher<DeviceClass> {
144    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
145        match self {
146            BoundInterfaceMatcher::Unbound => inspector.record_str(name, "Unbound"),
147            BoundInterfaceMatcher::Bound(interface) => {
148                inspector.record_inspectable_value(name, interface)
149            }
150        }
151    }
152}
153
/// A matcher to the socket mark.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarkMatcher {
    /// Matches a packet if it is unmarked.
    Unmarked,
    /// The packet carries a mark that is in the range after masking.
    ///
    /// The mark is ANDed with `mask` and the result is tested against
    /// `start..=end`. A missing mark counts as "not in range" before `invert`
    /// is applied.
    Marked {
        /// The mask to apply.
        mask: u32,
        /// Start of the range, inclusive.
        start: u32,
        /// End of the range, inclusive.
        end: u32,
        /// Inverts the meaning of the match.
        invert: bool,
    },
}
171
172impl Matcher<Mark> for MarkMatcher {
173    fn matches(&self, Mark(actual): &Mark) -> bool {
174        match self {
175            MarkMatcher::Unmarked => actual.is_none(),
176            MarkMatcher::Marked { mask, start, end, invert } => {
177                let val = actual.is_some_and(|actual| (*start..=*end).contains(&(actual & *mask)));
178
179                if *invert { !val } else { val }
180            }
181        }
182    }
183}
184
/// A matcher for the mark in a specific domain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MarkInDomainMatcher {
    /// The domain of the mark to match.
    pub domain: MarkDomain,
    /// The matcher for the mark.
    pub matcher: MarkMatcher,
}
193
/// The 2 mark matchers a rule can specify. All matchers that are not `None`
/// must match.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct MarkMatchers(MarkStorage<Option<MarkMatcher>>);
197
198impl MarkMatchers {
199    /// Creates [`MarkMatcher`]s from an iterator of `(MarkDomain, MarkMatcher)`.
200    ///
201    /// An unspecified domain will not have a matcher.
202    ///
203    /// # Panics
204    ///
205    /// Panics if the same domain is specified more than once.
206    pub fn new(matchers: impl IntoIterator<Item = (MarkDomain, MarkMatcher)>) -> Self {
207        MarkMatchers(MarkStorage::new(matchers))
208    }
209
210    /// Returns an iterator over the mark matchers of all domains.
211    pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &Option<MarkMatcher>)> {
212        let Self(storage) = self;
213        storage.iter()
214    }
215}
216
217impl Matcher<Marks> for MarkMatchers {
218    fn matches(&self, actual: &Marks) -> bool {
219        let Self(matchers) = self;
220        matchers.zip_with(actual).all(|(_domain, matcher, actual)| matcher.matches(actual))
221    }
222}
223
/// A matcher for a socket's cookie.
///
/// Matches when the tested cookie equals `cookie`, or when it differs if
/// `invert` is set.
pub struct SocketCookieMatcher {
    /// The cookie to check against.
    pub cookie: u64,
    /// Invert the matching criterion (i.e. if the socket cookie isn't the same,
    /// it matches).
    pub invert: bool,
}
232
233impl Matcher<u64> for SocketCookieMatcher {
234    fn matches(&self, actual: &u64) -> bool {
235        let val = *actual == self.cookie;
236        if self.invert { !val } else { val }
237    }
238}
239
/// A matcher for transport-layer port numbers.
///
/// The result of the range test is XORed with `invert`.
#[derive(Clone, Debug)]
pub struct PortMatcher {
    /// The range of port numbers in which the tested port number must fall.
    pub range: RangeInclusive<u16>,
    /// Whether to check for an "inverse" or "negative" match (in which case,
    /// if the matcher criteria do *not* apply, it *is* considered a match, and
    /// vice versa).
    pub invert: bool,
}
250
251impl Matcher<u16> for PortMatcher {
252    fn matches(&self, actual: &u16) -> bool {
253        let Self { range, invert } = self;
254        range.contains(actual) ^ *invert
255    }
256}
257
bitflags! {
    /// A matcher for TCP state machine state.
    ///
    /// A set of flags, one per TCP state. A [`TcpSocketState`] matches when
    /// the flag for that state is contained in the set.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct TcpStateMatcher: u32 {
        /// The TCP ESTABLISHED state.
        const ESTABLISHED = 1 << 0;
        /// The TCP SYN_SENT state.
        const SYN_SENT = 1 << 1;
        /// The TCP SYN_RECV state.
        const SYN_RECV = 1 << 2;
        /// The TCP FIN_WAIT1 state.
        const FIN_WAIT1 = 1 << 3;
        /// The TCP FIN_WAIT2 state.
        const FIN_WAIT2 = 1 << 4;
        /// The TCP TIME_WAIT state.
        const TIME_WAIT = 1 << 5;
        /// The TCP CLOSE state.
        const CLOSE = 1 << 6;
        /// The TCP CLOSE_WAIT state.
        const CLOSE_WAIT = 1 << 7;
        /// The TCP LAST_ACK state.
        const LAST_ACK = 1 << 8;
        /// The TCP LISTEN state.
        const LISTEN = 1 << 9;
        /// The TCP CLOSING state.
        const CLOSING = 1 << 10;
    }
}
286
287impl Matcher<TcpSocketState> for TcpStateMatcher {
288    fn matches(&self, actual: &TcpSocketState) -> bool {
289        self.contains(actual.matcher_flag())
290    }
291}
292
/// Represents the state of a TCP socket's state machine.
///
/// Each variant corresponds to the [`TcpStateMatcher`] flag of the same name.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum TcpSocketState {
    Established,
    SynSent,
    SynRecv,
    FinWait1,
    FinWait2,
    TimeWait,
    Close,
    CloseWait,
    LastAck,
    Listen,
    Closing,
}
309
310impl TcpSocketState {
311    fn matcher_flag(&self) -> TcpStateMatcher {
312        match self {
313            TcpSocketState::Established => TcpStateMatcher::ESTABLISHED,
314            TcpSocketState::SynSent => TcpStateMatcher::SYN_SENT,
315            TcpSocketState::SynRecv => TcpStateMatcher::SYN_RECV,
316            TcpSocketState::FinWait1 => TcpStateMatcher::FIN_WAIT1,
317            TcpSocketState::FinWait2 => TcpStateMatcher::FIN_WAIT2,
318            TcpSocketState::TimeWait => TcpStateMatcher::TIME_WAIT,
319            TcpSocketState::Close => TcpStateMatcher::CLOSE,
320            TcpSocketState::CloseWait => TcpStateMatcher::CLOSE_WAIT,
321            TcpSocketState::LastAck => TcpStateMatcher::LAST_ACK,
322            TcpSocketState::Listen => TcpStateMatcher::LISTEN,
323            TcpSocketState::Closing => TcpStateMatcher::CLOSING,
324        }
325    }
326}
327
/// Allows code to match on properties of a TCP socket without Netstack3 Core
/// having to specifically expose that state.
///
/// [`TcpSocketMatcher`] calls the method that corresponds to its variant.
pub trait TcpSocketProperties {
    /// Returns whether the socket's source port is matched by the matcher.
    fn src_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's destination port is matched by the matcher.
    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's TCP state is matched by the matcher.
    fn state_matches(&self, matcher: &TcpStateMatcher) -> bool;
}
340
341impl TcpSocketProperties for Never {
342    fn src_port_matches(&self, _matcher: &PortMatcher) -> bool {
343        unimplemented!()
344    }
345
346    fn dst_port_matches(&self, _matcher: &PortMatcher) -> bool {
347        unimplemented!()
348    }
349
350    fn state_matches(&self, _matcher: &TcpStateMatcher) -> bool {
351        unimplemented!()
352    }
353}
354
355impl<T> TcpSocketProperties for &T
356where
357    T: TcpSocketProperties,
358{
359    fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
360        (*self).src_port_matches(matcher)
361    }
362
363    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
364        (*self).dst_port_matches(matcher)
365    }
366
367    fn state_matches(&self, matcher: &TcpStateMatcher) -> bool {
368        (*self).state_matches(matcher)
369    }
370}
371
/// The top-level matcher for TCP sockets.
///
/// Each non-empty variant checks a single property of the socket through
/// [`TcpSocketProperties`].
pub enum TcpSocketMatcher {
    /// Match any TCP socket without further constraints.
    Empty,
    /// Match on the source port.
    SrcPort(PortMatcher),
    /// Match on the destination port.
    DstPort(PortMatcher),
    /// Match on the state of the TCP state machine.
    State(TcpStateMatcher),
}
383
384impl<T: TcpSocketProperties> Matcher<T> for TcpSocketMatcher {
385    fn matches(&self, actual: &T) -> bool {
386        match self {
387            TcpSocketMatcher::Empty => true,
388            TcpSocketMatcher::SrcPort(matcher) => actual.src_port_matches(matcher),
389            TcpSocketMatcher::DstPort(matcher) => actual.dst_port_matches(matcher),
390            TcpSocketMatcher::State(matcher) => actual.state_matches(matcher),
391        }
392    }
393}
394
bitflags! {
    /// A matcher for UDP states.
    ///
    /// A set of flags, one per UDP socket state. A [`UdpSocketState`] matches
    /// when the flag for that state is contained in the set.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct UdpStateMatcher: u32 {
        /// The UDP socket is bound but not connected.
        const BOUND = 1 << 0;
        /// The UDP socket is explicitly connected.
        const CONNECTED = 1 << 1;
    }
}
405
406impl Matcher<UdpSocketState> for UdpStateMatcher {
407    fn matches(&self, actual: &UdpSocketState) -> bool {
408        self.contains(actual.matcher_flag())
409    }
410}
411
/// Represents the state of a UDP socket.
///
/// Each variant corresponds to the [`UdpStateMatcher`] flag of the same name.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum UdpSocketState {
    /// The socket is bound to a local address and (maybe) port.
    Bound,
    /// The socket is connected to a remote peer and has a full 4-tuple.
    Connected,
}
420
421impl UdpSocketState {
422    fn matcher_flag(&self) -> UdpStateMatcher {
423        match self {
424            UdpSocketState::Bound => UdpStateMatcher::BOUND,
425            UdpSocketState::Connected => UdpStateMatcher::CONNECTED,
426        }
427    }
428}
429
/// Allows code to match on properties of a UDP socket without Netstack3 Core
/// having to specifically expose that state.
///
/// [`UdpSocketMatcher`] calls the method that corresponds to its variant.
pub trait UdpSocketProperties {
    /// Returns whether the socket's source port is matched by the matcher.
    fn src_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's destination port is matched by the matcher.
    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's UDP state is matched by the matcher.
    fn state_matches(&self, matcher: &UdpStateMatcher) -> bool;
}
442
443impl UdpSocketProperties for Never {
444    fn src_port_matches(&self, _matcher: &PortMatcher) -> bool {
445        unimplemented!()
446    }
447
448    fn dst_port_matches(&self, _matcher: &PortMatcher) -> bool {
449        unimplemented!()
450    }
451
452    fn state_matches(&self, _matcher: &UdpStateMatcher) -> bool {
453        unimplemented!()
454    }
455}
456
457impl<U> UdpSocketProperties for &U
458where
459    U: UdpSocketProperties,
460{
461    fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
462        (*self).src_port_matches(matcher)
463    }
464
465    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
466        (*self).dst_port_matches(matcher)
467    }
468
469    fn state_matches(&self, matcher: &UdpStateMatcher) -> bool {
470        (*self).state_matches(matcher)
471    }
472}
473
/// The top-level matcher for UDP sockets.
///
/// Each non-empty variant checks a single property of the socket through
/// [`UdpSocketProperties`].
pub enum UdpSocketMatcher {
    /// Match any UDP socket without further constraints.
    Empty,
    /// Match the source port.
    SrcPort(PortMatcher),
    /// Match the destination port.
    DstPort(PortMatcher),
    /// Match the UDP state.
    State(UdpStateMatcher),
}
485
486impl<T: UdpSocketProperties> Matcher<T> for UdpSocketMatcher {
487    fn matches(&self, actual: &T) -> bool {
488        match self {
489            UdpSocketMatcher::Empty => true,
490            UdpSocketMatcher::SrcPort(matcher) => actual.src_port_matches(matcher),
491            UdpSocketMatcher::DstPort(matcher) => actual.dst_port_matches(matcher),
492            UdpSocketMatcher::State(matcher) => actual.state_matches(matcher),
493        }
494    }
495}
496
/// Provides optional access to transport-layer (TCP or UDP) socket
/// properties.
///
/// [`SocketTransportProtocolMatcher`] treats a `None` returned from the
/// relevant accessor as a non-match.
pub trait MaybeSocketTransportProperties {
    /// The type that encapsulates TCP socket properties.
    type TcpProps<'a>: TcpSocketProperties
    where
        Self: 'a;

    /// The type that encapsulates UDP socket properties.
    type UdpProps<'a>: UdpSocketProperties
    where
        Self: 'a;

    /// Returns TCP socket properties if the socket is a TCP socket.
    fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>>;

    /// Returns UDP socket properties if the socket is a UDP socket.
    fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>>;
}
515
516impl MaybeSocketTransportProperties for Never {
517    type TcpProps<'a>
518        = Never
519    where
520        Self: 'a;
521
522    type UdpProps<'a>
523        = Never
524    where
525        Self: 'a;
526
527    fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
528        unimplemented!()
529    }
530
531    fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
532        unimplemented!()
533    }
534}
535
/// A matcher for the transport protocol of a socket.
///
/// A variant only matches sockets that expose properties for the same
/// protocol and whose properties satisfy the inner matcher.
pub enum SocketTransportProtocolMatcher {
    /// Match against a TCP socket.
    Tcp(TcpSocketMatcher),
    /// Match against a UDP socket.
    Udp(UdpSocketMatcher),
}
543
544impl<T: MaybeSocketTransportProperties> Matcher<T> for SocketTransportProtocolMatcher {
545    fn matches(&self, actual: &T) -> bool {
546        match self {
547            SocketTransportProtocolMatcher::Tcp(tcp_matcher) => {
548                actual.tcp_socket_properties().map_or(false, |props| tcp_matcher.matches(props))
549            }
550            SocketTransportProtocolMatcher::Udp(udp_matcher) => {
551                actual.udp_socket_properties().map_or(false, |props| udp_matcher.matches(props))
552            }
553        }
554    }
555}
556
/// The criterion an IP address matcher applies: containment in a subnet or
/// in an inclusive address range.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub enum AddressMatcherType<A: IpAddress> {
    /// A subnet that must contain the address.
    #[derivative(Debug = "transparent")]
    Subnet(SubnetMatcher<A>),
    /// An inclusive range of IP addresses that must contain the address.
    Range(RangeInclusive<A>),
}
567
568impl<A: IpAddress> Matcher<A> for AddressMatcherType<A> {
569    fn matches(&self, actual: &A) -> bool {
570        match self {
571            Self::Subnet(subnet_matcher) => subnet_matcher.matches(actual),
572            Self::Range(range) => range.contains(actual),
573        }
574    }
575}
576
/// A matcher for IP addresses.
///
/// Combines an [`AddressMatcherType`] with an `invert` flag; the result of
/// the inner matcher is XORed with `invert`.
#[derive(Clone, Debug)]
pub struct AddressMatcher<A: IpAddress> {
    /// The type of the address matcher.
    pub matcher: AddressMatcherType<A>,
    /// Whether to check for an "inverse" or "negative" match (in which case,
    /// if the matcher criteria do *not* apply, it *is* considered a match, and
    /// vice versa).
    pub invert: bool,
}
587
588impl<A: IpAddress> InspectableValue for AddressMatcher<A> {
589    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
590        let AddressMatcher { matcher, invert } = self;
591
592        inspector.record_child(name, |inspector| {
593            inspector.record_bool("invert", *invert);
594            match matcher {
595                AddressMatcherType::Subnet(SubnetMatcher(subnet)) => {
596                    inspector.record_display("subnet", subnet)
597                }
598                AddressMatcherType::Range(range) => {
599                    inspector.record_display("start", range.start());
600                    inspector.record_display("end", range.end());
601                }
602            }
603        })
604    }
605}
606
607impl<A: IpAddress> Matcher<A> for AddressMatcher<A> {
608    fn matches(&self, addr: &A) -> bool {
609        let Self { matcher, invert } = self;
610        matcher.matches(addr) ^ *invert
611    }
612}
613
/// An address matcher that matches any IP version as specified at runtime.
///
/// An address of the other IP version never matches, regardless of the
/// inner matcher's `invert` flag.
pub enum AddressMatcherEither {
    /// The top-level IPv4 address matcher.
    V4(AddressMatcher<Ipv4Addr>),
    /// The top-level IPv6 address matcher.
    V6(AddressMatcher<Ipv6Addr>),
}
621
622impl Matcher<IpAddr> for AddressMatcherEither {
623    fn matches(&self, addr: &IpAddr) -> bool {
624        match self {
625            AddressMatcherEither::V4(matcher) => match addr {
626                IpAddr::V4(addr) => matcher.matches(addr),
627                IpAddr::V6(_) => false,
628            },
629            AddressMatcherEither::V6(matcher) => match addr {
630                IpAddr::V4(_) => false,
631                IpAddr::V6(addr) => matcher.matches(addr),
632            },
633        }
634    }
635}
636
/// Allows code to match on properties of a socket without Netstack3 Core
/// having to specifically expose that state.
///
/// [`IpSocketMatcher`] calls the method that corresponds to its variant.
pub trait IpSocketProperties<DeviceClass> {
    /// Returns whether the provided IP version matches the socket.
    fn family_matches(&self, family: &net_types::ip::IpVersion) -> bool;

    /// Returns whether the provided address matcher matches the socket's source
    /// address.
    fn src_addr_matches(&self, addr: &AddressMatcherEither) -> bool;

    /// Returns whether the provided address matcher matches the socket's
    /// destination address.
    fn dst_addr_matches(&self, addr: &AddressMatcherEither) -> bool;

    /// Returns whether the transport protocol matches the socket's
    /// transport-layer information.
    fn transport_protocol_matches(&self, matcher: &SocketTransportProtocolMatcher) -> bool;

    /// Returns whether the provided interface matcher matches the socket's
    /// bound interface, if present.
    fn bound_interface_matches(&self, iface: &BoundInterfaceMatcher<DeviceClass>) -> bool;

    /// Returns whether the provided cookie matcher matches the socket's cookie.
    fn cookie_matches(&self, cookie: &SocketCookieMatcher) -> bool;

    /// Returns whether the provided mark matcher matches the corresponding mark.
    fn mark_matches(&self, matcher: &MarkInDomainMatcher) -> bool;
}
665
/// The top-level matcher for IP sockets.
///
/// Each variant checks a single property of the socket through
/// [`IpSocketProperties`].
pub enum IpSocketMatcher<DeviceClass> {
    /// Matches the socket's address family.
    Family(net_types::ip::IpVersion),
    /// Matches the socket's source address.
    SrcAddr(AddressMatcherEither),
    /// Matches the socket's destination address.
    DstAddr(AddressMatcherEither),
    /// Matches the socket's transport protocol.
    Proto(SocketTransportProtocolMatcher),
    /// Matches the socket's bound interface.
    BoundInterface(BoundInterfaceMatcher<DeviceClass>),
    /// Matches the socket's cookie.
    Cookie(SocketCookieMatcher),
    /// Matches the socket's mark.
    Mark(MarkInDomainMatcher),
}
683
684impl<DeviceClass, S: IpSocketProperties<DeviceClass>> Matcher<S> for IpSocketMatcher<DeviceClass> {
685    fn matches(&self, actual: &S) -> bool {
686        match self {
687            IpSocketMatcher::Family(family) => actual.family_matches(family),
688            IpSocketMatcher::SrcAddr(addr) => actual.src_addr_matches(addr),
689            IpSocketMatcher::DstAddr(addr) => actual.dst_addr_matches(addr),
690            IpSocketMatcher::Proto(proto) => actual.transport_protocol_matches(proto),
691            IpSocketMatcher::BoundInterface(iface) => actual.bound_interface_matches(iface),
692            IpSocketMatcher::Cookie(cookie) => actual.cookie_matches(cookie),
693            IpSocketMatcher::Mark(mark) => actual.mark_matches(mark),
694        }
695    }
696}
697
/// Allows code to take an opaque matcher that works on IP sockets without
/// needing to know the type(s) of the underlying matcher(s).
///
/// Implemented in this file for a single [`IpSocketMatcher`] and for a slice
/// of them.
pub trait IpSocketPropertiesMatcher<DeviceClass> {
    /// Whether the matcher matches `actual`.
    fn matches_ip_socket<S: IpSocketProperties<DeviceClass>>(&self, actual: &S) -> bool;
}
704
705impl<DeviceClass> IpSocketPropertiesMatcher<DeviceClass> for IpSocketMatcher<DeviceClass> {
706    fn matches_ip_socket<S: IpSocketProperties<DeviceClass>>(&self, actual: &S) -> bool {
707        self.matches(actual)
708    }
709}
710
711impl<DeviceClass> IpSocketPropertiesMatcher<DeviceClass> for [IpSocketMatcher<DeviceClass>] {
712    fn matches_ip_socket<S: IpSocketProperties<DeviceClass>>(&self, actual: &S) -> bool {
713        self.iter().all(|matcher| matcher.matches(actual))
714    }
715}
716
#[cfg(any(test, feature = "testutils"))]
pub(crate) mod testutil {
    use alloc::string::String;
    use core::num::NonZeroU64;

    use crate::matchers::InterfaceProperties;
    use crate::testutil::{FakeDeviceClass, FakeStrongDeviceId, FakeWeakDeviceId};
    use crate::{DeviceIdentifier, StrongDeviceIdentifier};

    /// A fake device ID for testing matchers.
    ///
    /// Carries the three properties an interface matcher can compare
    /// against: numeric ID, name and device class.
    #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    #[allow(missing_docs)]
    pub struct FakeMatcherDeviceId {
        pub id: NonZeroU64,
        pub name: String,
        pub class: FakeDeviceClass,
    }

    impl FakeMatcherDeviceId {
        /// Returns a [`FakeMatcherDeviceId`] for an arbitrary WLAN interface.
        ///
        /// The interface returned will always be identical.
        pub fn wlan_interface() -> FakeMatcherDeviceId {
            Self::fixed(1, "wlan", FakeDeviceClass::Wlan)
        }

        /// Returns a [`FakeMatcherDeviceId`] for an arbitrary Ethernet interface.
        ///
        /// The interface returned will always be identical.
        pub fn ethernet_interface() -> FakeMatcherDeviceId {
            Self::fixed(2, "eth", FakeDeviceClass::Ethernet)
        }

        /// Builds a device from fixed values; `id` must be non-zero.
        fn fixed(id: u64, name: &str, class: FakeDeviceClass) -> FakeMatcherDeviceId {
            FakeMatcherDeviceId {
                id: NonZeroU64::new(id).unwrap(),
                name: String::from(name),
                class,
            }
        }
    }

    impl StrongDeviceIdentifier for FakeMatcherDeviceId {
        type Weak = FakeWeakDeviceId<Self>;

        fn downgrade(&self) -> Self::Weak {
            let strong = self.clone();
            FakeWeakDeviceId(strong)
        }
    }

    impl DeviceIdentifier for FakeMatcherDeviceId {
        // Fake matcher devices are never loopback.
        fn is_loopback(&self) -> bool {
            false
        }
    }

    impl FakeStrongDeviceId for FakeMatcherDeviceId {
        // Fake matcher devices are always reported as alive.
        fn is_alive(&self) -> bool {
            true
        }
    }

    impl PartialEq<FakeWeakDeviceId<FakeMatcherDeviceId>> for FakeMatcherDeviceId {
        fn eq(&self, weak: &FakeWeakDeviceId<FakeMatcherDeviceId>) -> bool {
            // Compare against the strong ID wrapped by the weak one.
            let FakeWeakDeviceId(strong) = weak;
            self == strong
        }
    }

    impl InterfaceProperties<FakeDeviceClass> for FakeMatcherDeviceId {
        fn id_matches(&self, id: &NonZeroU64) -> bool {
            self.id == *id
        }

        fn name_matches(&self, name: &str) -> bool {
            self.name.as_str() == name
        }

        fn device_class_matches(&self, class: &FakeDeviceClass) -> bool {
            self.class == *class
        }
    }
}
799
800#[cfg(test)]
801mod tests {
802    use ip_test_macro::ip_test;
803    use net_types::Witness;
804    use net_types::ip::{Ip, IpVersion, Ipv4, Ipv6};
805    use test_case::test_case;
806
807    use super::*;
808    use crate::testutil::{FakeDeviceClass, FakeMatcherDeviceId, TestIpExt};
809
810    /// Only matches `true`.
811    #[derive(Debug)]
812    struct TrueMatcher;
813
814    impl Matcher<bool> for TrueMatcher {
815        fn matches(&self, actual: &bool) -> bool {
816            *actual
817        }
818    }
819
    // Exercises every combination of required/optional matcher with
    // present/absent value.
    #[test]
    fn test_optional_matcher_optional_value() {
        // Plain matcher, plain value.
        assert!(TrueMatcher.matches(&true));
        assert!(!TrueMatcher.matches(&false));

        // Plain matcher, optional value: an absent value never matches.
        assert!(TrueMatcher.required_matches(Some(&true)));
        assert!(!TrueMatcher.required_matches(Some(&false)));
        assert!(!TrueMatcher.required_matches(None));

        // Optional matcher, plain value: an absent matcher matches anything.
        assert!(Some(TrueMatcher).matches(&true));
        assert!(!Some(TrueMatcher).matches(&false));
        assert!(None::<TrueMatcher>.matches(&true));
        assert!(None::<TrueMatcher>.matches(&false));

        // Optional matcher, optional value: an absent matcher matches even
        // an absent value.
        assert!(Some(TrueMatcher).required_matches(Some(&true)));
        assert!(!Some(TrueMatcher).required_matches(Some(&false)));
        assert!(!Some(TrueMatcher).required_matches(None));
        assert!(None::<TrueMatcher>.required_matches(Some(&true)));
        assert!(None::<TrueMatcher>.required_matches(Some(&false)));
        assert!(None::<TrueMatcher>.required_matches(None));
    }
841
    // Each case pairs an interface matcher built from the WLAN fake device
    // with either the WLAN or the Ethernet fake device; only the WLAN device
    // is expected to match.
    #[test_case(
        InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id),
        FakeMatcherDeviceId::wlan_interface() => true
    )]
    #[test_case(
        InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id),
        FakeMatcherDeviceId::ethernet_interface() => false
    )]
    #[test_case(
        InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name),
        FakeMatcherDeviceId::wlan_interface() => true
    )]
    #[test_case(
        InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name),
        FakeMatcherDeviceId::ethernet_interface() => false
    )]
    #[test_case(
        InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan),
        FakeMatcherDeviceId::wlan_interface() => true
    )]
    #[test_case(
        InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan),
        FakeMatcherDeviceId::ethernet_interface() => false
    )]
    fn interface_matcher(
        matcher: InterfaceMatcher<FakeDeviceClass>,
        device: FakeMatcherDeviceId,
    ) -> bool {
        matcher.matches(&device)
    }
872
    // `Unbound` must match only an absent device. `Bound(..)` must match
    // only a present device that satisfies the inner matcher, so every
    // `Bound` matcher is checked against no device, the WLAN device and the
    // Ethernet device.
    #[test_case(BoundInterfaceMatcher::Unbound, None => true)]
    #[test_case(
        BoundInterfaceMatcher::Unbound,
        Some(FakeMatcherDeviceId::wlan_interface()) => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        ),
        None => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        ),
        Some(FakeMatcherDeviceId::wlan_interface()) => true
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        ),
        Some(FakeMatcherDeviceId::ethernet_interface()) => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
        ),
        None => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
        ),
        Some(FakeMatcherDeviceId::wlan_interface()) => true
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
        ),
        Some(FakeMatcherDeviceId::ethernet_interface()) => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
        ),
        None => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
        ),
        Some(FakeMatcherDeviceId::wlan_interface()) => true
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
        ),
        Some(FakeMatcherDeviceId::ethernet_interface()) => false
    )]
    fn bound_interface_matcher(
        matcher: BoundInterfaceMatcher<FakeDeviceClass>,
        device: Option<FakeMatcherDeviceId>,
    ) -> bool {
        // The matcher is implemented for `Option<&D>`, so borrow the device.
        matcher.matches(&device.as_ref())
    }
938
939    #[ip_test(I)]
940    fn subnet_matcher<I: Ip + TestIpExt>() {
941        let matcher = SubnetMatcher(I::TEST_ADDRS.subnet);
942        assert!(matcher.matches(&I::TEST_ADDRS.local_ip));
943        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
944    }
945
946    #[test_case(MarkMatcher::Unmarked, Mark(None) => true; "unmarked matches none")]
947    #[test_case(MarkMatcher::Unmarked, Mark(Some(0)) => false; "unmarked does not match some")]
948    #[test_case(MarkMatcher::Marked {
949        mask: 1,
950        start: 0,
951        end: 0,
952        invert: false,
953    }, Mark(None) => false; "marked does not match none")]
954    #[test_case(MarkMatcher::Marked {
955        mask: 1,
956        start: 0,
957        end: 0,
958        invert: false,
959    }, Mark(Some(0)) => true; "marked 0 mask 1 matches 0")]
960    #[test_case(MarkMatcher::Marked {
961        mask: 1,
962        start: 0,
963        end: 0,
964        invert: false,
965    }, Mark(Some(1)) => false; "marked 0 mask 1 does not match 1")]
966    #[test_case(MarkMatcher::Marked {
967        mask: 1,
968        start: 0,
969        end: 0,
970        invert: false,
971    }, Mark(Some(2)) => true; "marked 0 mask 1 matches 2")]
972    #[test_case(MarkMatcher::Marked {
973        mask: 1,
974        start: 0,
975        end: 0,
976        invert: false,
977    }, Mark(Some(3)) => false; "marked 0 mask 1 does not match 3")]
978    #[test_case(MarkMatcher::Marked {
979        mask: !0,
980        start: 0,
981        end: 10,
982        invert: true,
983    }, Mark(Some(5)) => false; "marked invert no match in range")]
984    #[test_case(MarkMatcher::Marked {
985        mask: !0,
986        start: 0,
987        end: 10,
988        invert: true,
989    }, Mark(Some(11)) => true; "marked invert matches out of range")]
990    fn mark_matcher(matcher: MarkMatcher, mark: Mark) -> bool {
991        matcher.matches(&mark)
992    }
993
994    #[test_case(
995        MarkMatchers::new(
996            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
997            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
998        ),
999        Marks::new([]) => true;
1000        "all unmarked matches empty"
1001    )]
1002    #[test_case(
1003        MarkMatchers::new(
1004            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1005            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1006        ),
1007        Marks::new([(MarkDomain::Mark1, 1)]) => false;
1008        "all unmarked does not match mark1"
1009    )]
1010    #[test_case(
1011        MarkMatchers::new(
1012            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1013            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1014        ),
1015        Marks::new([(MarkDomain::Mark2, 1)]) => false;
1016        "all unmarked does not match mark2"
1017    )]
1018    #[test_case(
1019        MarkMatchers::new(
1020            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1021            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1022        ),
1023        Marks::new([
1024            (MarkDomain::Mark1, 1),
1025            (MarkDomain::Mark2, 1),
1026        ]) => false;
1027        "all unmarked does not match mark1 and mark2"
1028    )]
1029    #[test_case(
1030        MarkMatchers::new(
1031            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1032            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1033        ),
1034        Marks::new([(MarkDomain::Mark1, 1)]) => true;
1035        "mark1 marked matches"
1036    )]
1037    #[test_case(
1038        MarkMatchers::new(
1039            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1040            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1041        ),
1042        Marks::new([(MarkDomain::Mark1, 2)]) => false;
1043        "mark1 marked no match"
1044    )]
1045    #[test_case(
1046        MarkMatchers::new(
1047            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1048            (MarkDomain::Mark2, MarkMatcher::Marked { mask: !0, start: 2, end: 2, invert: false })]
1049        ),
1050        Marks::new([(MarkDomain::Mark1, 1), (MarkDomain::Mark2, 2)]) => true;
1051        "all marked matches"
1052    )]
1053    #[test_case(
1054        MarkMatchers::new(
1055            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1056            (MarkDomain::Mark2, MarkMatcher::Marked { mask: !0, start: 2, end: 2, invert: false })]
1057        ),
1058        Marks::new([(MarkDomain::Mark1, 1), (MarkDomain::Mark2, 3)]) => false;
1059        "all marked no match mark2"
1060    )]
1061    fn mark_matchers(matchers: MarkMatchers, marks: Marks) -> bool {
1062        matchers.matches(&marks)
1063    }
1064
1065    #[test_case(SocketCookieMatcher { cookie: 123, invert: false }, 123 => true)]
1066    #[test_case(SocketCookieMatcher { cookie: 123, invert: false }, 456 => false)]
1067    #[test_case(SocketCookieMatcher { cookie: 123, invert: true }, 123 => false)]
1068    #[test_case(SocketCookieMatcher { cookie: 123, invert: true }, 456 => true)]
1069    fn socket_cookie_matcher(matcher: SocketCookieMatcher, actual: u64) -> bool {
1070        matcher.matches(&actual)
1071    }
1072
1073    #[test_case(PortMatcher { range: 10..=20, invert: false }, 9 => false)]
1074    #[test_case(PortMatcher { range: 10..=20, invert: false }, 10 => true)]
1075    #[test_case(PortMatcher { range: 10..=20, invert: false }, 15 => true)]
1076    #[test_case(PortMatcher { range: 10..=20, invert: false }, 20 => true)]
1077    #[test_case(PortMatcher { range: 10..=20, invert: false }, 21 => false)]
1078    #[test_case(PortMatcher { range: 10..=20, invert: true }, 9 => true)]
1079    #[test_case(PortMatcher { range: 10..=20, invert: true }, 10 => false)]
1080    #[test_case(PortMatcher { range: 10..=20, invert: true }, 15 => false)]
1081    #[test_case(PortMatcher { range: 10..=20, invert: true }, 20 => false)]
1082    #[test_case(PortMatcher { range: 10..=20, invert: true }, 21 => true)]
1083    fn port_matcher(matcher: PortMatcher, actual: u16) -> bool {
1084        matcher.matches(&actual)
1085    }
1086
1087    struct FakeTcpSocket {
1088        src_port: u16,
1089        dst_port: u16,
1090        state: TcpSocketState,
1091    }
1092
1093    impl MaybeSocketTransportProperties for FakeTcpSocket {
1094        type TcpProps<'a>
1095            = Self
1096        where
1097            Self: 'a;
1098
1099        type UdpProps<'a>
1100            = Never
1101        where
1102            Self: 'a;
1103
1104        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1105            Some(self)
1106        }
1107
1108        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1109            None
1110        }
1111    }
1112
1113    impl TcpSocketProperties for FakeTcpSocket {
1114        fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
1115            matcher.matches(&self.src_port)
1116        }
1117
1118        fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
1119            matcher.matches(&self.dst_port)
1120        }
1121
1122        fn state_matches(&self, matcher: &TcpStateMatcher) -> bool {
1123            matcher.matches(&self.state)
1124        }
1125    }
1126
1127    struct FakeUdpSocket {
1128        src_port: u16,
1129        dst_port: u16,
1130        state: UdpSocketState,
1131    }
1132
1133    impl MaybeSocketTransportProperties for FakeUdpSocket {
1134        type TcpProps<'a>
1135            = Never
1136        where
1137            Self: 'a;
1138
1139        type UdpProps<'a>
1140            = Self
1141        where
1142            Self: 'a;
1143
1144        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1145            None
1146        }
1147
1148        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1149            Some(self)
1150        }
1151    }
1152
1153    impl UdpSocketProperties for FakeUdpSocket {
1154        fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
1155            matcher.matches(&self.src_port)
1156        }
1157
1158        fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
1159            matcher.matches(&self.dst_port)
1160        }
1161
1162        fn state_matches(&self, matcher: &UdpStateMatcher) -> bool {
1163            matcher.matches(&self.state)
1164        }
1165    }
1166
1167    struct FakeIpSocket<I, T>
1168    where
1169        I: TestIpExt,
1170        T: MaybeSocketTransportProperties,
1171    {
1172        src_ip: I::Addr,
1173        dst_ip: I::Addr,
1174        proto: T,
1175        intf: Option<FakeMatcherDeviceId>,
1176        cookie: u64,
1177        marks: Marks,
1178    }
1179
1180    impl<I, T> MaybeSocketTransportProperties for FakeIpSocket<I, T>
1181    where
1182        I: TestIpExt,
1183        T: MaybeSocketTransportProperties,
1184    {
1185        type TcpProps<'a>
1186            = T::TcpProps<'a>
1187        where
1188            Self: 'a;
1189
1190        type UdpProps<'a>
1191            = T::UdpProps<'a>
1192        where
1193            Self: 'a;
1194
1195        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1196            self.proto.tcp_socket_properties()
1197        }
1198
1199        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1200            self.proto.udp_socket_properties()
1201        }
1202    }
1203
1204    impl<I, T> IpSocketProperties<FakeDeviceClass> for FakeIpSocket<I, T>
1205    where
1206        I: TestIpExt,
1207        T: MaybeSocketTransportProperties,
1208    {
1209        fn family_matches(&self, family: &net_types::ip::IpVersion) -> bool {
1210            *family == I::VERSION
1211        }
1212
1213        fn src_addr_matches(&self, addr: &AddressMatcherEither) -> bool {
1214            addr.matches(&self.src_ip.into())
1215        }
1216
1217        fn dst_addr_matches(&self, addr: &AddressMatcherEither) -> bool {
1218            addr.matches(&self.dst_ip.into())
1219        }
1220
1221        fn transport_protocol_matches(&self, matcher: &SocketTransportProtocolMatcher) -> bool {
1222            matcher.matches(self)
1223        }
1224
1225        fn bound_interface_matches(&self, iface: &BoundInterfaceMatcher<FakeDeviceClass>) -> bool {
1226            iface.matches(&self.intf.as_ref())
1227        }
1228
1229        fn cookie_matches(&self, cookie: &SocketCookieMatcher) -> bool {
1230            cookie.matches(&self.cookie)
1231        }
1232
1233        fn mark_matches(&self, matcher: &MarkInDomainMatcher) -> bool {
1234            matcher.matcher.matches(self.marks.get(matcher.domain))
1235        }
1236    }
1237
1238    #[test_case(
1239        TcpSocketMatcher::Empty,
1240        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1241        "empty matcher"
1242    )]
1243    #[test_case(
1244        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false }),
1245        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1246        "src_port match"
1247    )]
1248    #[test_case(
1249        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false }),
1250        FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established } => false;
1251        "src_port no match"
1252    )]
1253    #[test_case(
1254        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: true }),
1255        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => false;
1256        "src_port invert no match"
1257    )]
1258    #[test_case(
1259        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: true }),
1260        FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established } => true;
1261        "src_port invert match"
1262    )]
1263    #[test_case(
1264        TcpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1265        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1266        "dst_port match"
1267    )]
1268    #[test_case(
1269        TcpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1270        FakeTcpSocket { src_port: 80, dst_port: 12346, state: TcpSocketState::Established } => false;
1271        "dst_port no match"
1272    )]
1273    #[test_case(
1274        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED),
1275        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1276        "state match"
1277    )]
1278    #[test_case(
1279        TcpSocketMatcher::State(TcpStateMatcher::SYN_SENT),
1280        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => false;
1281        "state no match"
1282    )]
1283    #[test_case(
1284        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
1285        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1286        "state multi match established"
1287    )]
1288    #[test_case(
1289        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
1290        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::SynSent } => true;
1291        "state multi match syn_sent"
1292    )]
1293    #[test_case(
1294        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
1295        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::FinWait1 } => false;
1296        "state multi no match"
1297    )]
1298    fn tcp_socket_matcher(matcher: TcpSocketMatcher, socket: FakeTcpSocket) -> bool {
1299        matcher.matches(&socket)
1300    }
1301
1302    #[test_case(
1303        UdpSocketMatcher::Empty,
1304        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1305        "empty matcher"
1306    )]
1307    #[test_case(
1308        UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false }),
1309        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1310        "src_port match"
1311    )]
1312    #[test_case(
1313        UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false }),
1314        FakeUdpSocket { src_port: 54, dst_port: 12345, state: UdpSocketState::Bound } => false;
1315        "src_port no match"
1316    )]
1317    #[test_case(
1318        UdpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1319        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1320        "dst_port match"
1321    )]
1322    #[test_case(
1323        UdpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1324        FakeUdpSocket { src_port: 53, dst_port: 12346, state: UdpSocketState::Bound } => false;
1325        "dst_port no match"
1326    )]
1327    #[test_case(
1328        UdpSocketMatcher::State(UdpStateMatcher::BOUND),
1329        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1330        "state match bound"
1331    )]
1332    #[test_case(
1333        UdpSocketMatcher::State(UdpStateMatcher::CONNECTED),
1334        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => false;
1335        "state no match connected"
1336    )]
1337    #[test_case(
1338        UdpSocketMatcher::State(UdpStateMatcher::BOUND | UdpStateMatcher::CONNECTED),
1339        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1340        "state multi match bound"
1341    )]
1342    #[test_case(
1343        UdpSocketMatcher::State(UdpStateMatcher::BOUND | UdpStateMatcher::CONNECTED),
1344        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Connected } => true;
1345        "state multi match connected"
1346    )]
1347    fn udp_socket_matcher(matcher: UdpSocketMatcher, socket: FakeUdpSocket) -> bool {
1348        matcher.matches(&socket)
1349    }
1350
1351    #[ip_test(I)]
1352    #[test_case(
1353        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Tcp(TcpSocketMatcher::Empty)),
1354        FakeIpSocket {
1355            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1356            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1357            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1358            cookie: 0,
1359            intf: None,
1360            marks: Marks::default(),
1361        } => true;
1362        "tcp empty"
1363    )]
1364    #[test_case(
1365        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Tcp(TcpSocketMatcher::Empty)),
1366        FakeIpSocket {
1367            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1368            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1369            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
1370            cookie: 0,
1371            intf: None,
1372            marks: Marks::default(),
1373        } => false;
1374        "tcp empty no match udp"
1375    )]
1376    #[test_case(
1377        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Udp(UdpSocketMatcher::Empty)),
1378        FakeIpSocket {
1379            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1380            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1381            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1382            cookie: 0,
1383            intf: None,
1384            marks: Marks::default(),
1385        } => false;
1386        "udp empty no match tcp"
1387    )]
1388    #[test_case(
1389        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Udp(UdpSocketMatcher::Empty)),
1390        FakeIpSocket {
1391            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1392            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1393            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
1394            cookie: 0,
1395            intf: None,
1396            marks: Marks::default(),
1397        } => true;
1398        "udp empty"
1399    )]
1400    #[test_case(
1401        IpSocketMatcher::Proto(
1402            SocketTransportProtocolMatcher::Tcp(
1403                TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false })
1404            )
1405        ),
1406        FakeIpSocket {
1407            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1408            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1409            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1410            cookie: 0,
1411            intf: None,
1412            marks: Marks::default(),
1413        } => true;
1414        "tcp src_port match"
1415    )]
1416    #[test_case(
1417        IpSocketMatcher::Proto(
1418            SocketTransportProtocolMatcher::Tcp(
1419                TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false })
1420            )
1421        ),
1422        FakeIpSocket {
1423            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1424            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1425            proto: FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established },
1426            cookie: 0,
1427            intf: None,
1428            marks: Marks::default(),
1429        } => false;
1430        "tcp src_port no match"
1431    )]
1432    #[test_case(
1433        IpSocketMatcher::Proto(
1434            SocketTransportProtocolMatcher::Udp(
1435                UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false })
1436            )
1437        ),
1438        FakeIpSocket {
1439            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1440            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1441            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
1442            cookie: 0,
1443            intf: None,
1444            marks: Marks::default(),
1445        } => true;
1446        "udp src_port match"
1447    )]
1448    #[test_case(
1449        IpSocketMatcher::Proto(
1450            SocketTransportProtocolMatcher::Udp(
1451                UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false })
1452            )
1453        ),
1454        FakeIpSocket {
1455            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1456            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1457            proto: FakeUdpSocket { src_port: 54, dst_port: 12345, state: UdpSocketState::Bound },
1458            cookie: 0,
1459            intf: None,
1460            marks: Marks::default(),
1461        } => false;
1462        "udp src_port no match"
1463    )]
1464    #[test_case(
1465        IpSocketMatcher::Cookie(SocketCookieMatcher { cookie: 123, invert: false }),
1466        FakeIpSocket {
1467            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1468            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1469            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1470            cookie: 123,
1471            intf: None,
1472            marks: Marks::default(),
1473        } => true;
1474        "cookie match"
1475    )]
1476    #[test_case(
1477        IpSocketMatcher::Cookie(SocketCookieMatcher { cookie: 123, invert: false }),
1478        FakeIpSocket {
1479            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1480            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1481            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1482            cookie: 456,
1483            intf: None,
1484            marks: Marks::default(),
1485        } => false;
1486        "cookie no match"
1487    )]
1488    #[test_case(
1489        IpSocketMatcher::Mark(MarkInDomainMatcher {
1490            domain: MarkDomain::Mark1,
1491            matcher: MarkMatcher::Unmarked,
1492        }),
1493        FakeIpSocket {
1494            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1495            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1496            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1497            cookie: 0,
1498            intf: None,
1499            marks: Marks::default(),
1500        } => true;
1501        "mark1 unmarked match"
1502    )]
1503    #[test_case(
1504        IpSocketMatcher::Mark(MarkInDomainMatcher {
1505            domain: MarkDomain::Mark1,
1506            matcher: MarkMatcher::Unmarked,
1507        }),
1508        FakeIpSocket {
1509            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1510            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1511            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1512            cookie: 0,
1513            intf: None,
1514            marks: Marks::new([(MarkDomain::Mark1, 1)]),
1515        } => false;
1516        "mark1 unmarked no match"
1517    )]
1518    #[test_case(
1519        IpSocketMatcher::Mark(MarkInDomainMatcher {
1520            domain: MarkDomain::Mark2,
1521            matcher: MarkMatcher::Unmarked,
1522        }),
1523        FakeIpSocket {
1524            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1525            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1526            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1527            cookie: 0,
1528            intf: None,
1529            marks: Marks::default(),
1530        } => true;
1531        "mark2 unmarked match"
1532    )]
1533    #[test_case(
1534        IpSocketMatcher::Mark(MarkInDomainMatcher {
1535            domain: MarkDomain::Mark2,
1536            matcher: MarkMatcher::Unmarked,
1537        }),
1538        FakeIpSocket {
1539            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1540            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1541            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1542            cookie: 0,
1543            intf: None,
1544            marks: Marks::new([(MarkDomain::Mark2, 1)]),
1545        } => false;
1546        "mark2 unmarked no match"
1547    )]
1548    #[test_case(
1549        IpSocketMatcher::BoundInterface(BoundInterfaceMatcher::Bound(
1550            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
1551        )),
1552        FakeIpSocket {
1553            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1554            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1555            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1556            cookie: 0,
1557            intf: Some(FakeMatcherDeviceId::wlan_interface()),
1558            marks: Marks::default(),
1559        } => true;
1560        "bound_interface match"
1561    )]
1562    #[test_case(
1563        IpSocketMatcher::BoundInterface(BoundInterfaceMatcher::Bound(
1564            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
1565        )),
1566        FakeIpSocket {
1567            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1568            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1569            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1570            cookie: 0,
1571            intf: Some(FakeMatcherDeviceId::ethernet_interface()),
1572            marks: Marks::default(),
1573        } => false;
1574        "bound_interface no match"
1575    )]
1576    fn ip_socket_matcher<I: TestIpExt, T: MaybeSocketTransportProperties>(
1577        matcher: IpSocketMatcher<FakeDeviceClass>,
1578        socket: FakeIpSocket<I, T>,
1579    ) -> bool {
1580        matcher.matches(&socket)
1581    }
1582
1583    #[ip_test(I)]
1584    fn address_matcher_type<I: TestIpExt>() {
1585        let local_ip = I::TEST_ADDRS.local_ip.get();
1586        let remote_ip = I::TEST_ADDRS.remote_ip.get();
1587
1588        let matcher = AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet));
1589        assert!(matcher.matches(&local_ip));
1590        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1591
1592        let matcher = AddressMatcherType::Range(local_ip..=remote_ip);
1593        assert!(matcher.matches(&local_ip));
1594        assert!(matcher.matches(&remote_ip));
1595        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1596    }
1597
1598    #[ip_test(I)]
1599    fn address_matcher<I: TestIpExt>() {
1600        let local_ip = I::TEST_ADDRS.local_ip.get();
1601        let remote_ip = I::TEST_ADDRS.remote_ip.get();
1602
1603        let matcher = AddressMatcher {
1604            matcher: AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet)),
1605            invert: false,
1606        };
1607        assert!(matcher.matches(&local_ip));
1608        assert!(matcher.matches(&remote_ip));
1609        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1610
1611        let matcher = AddressMatcher {
1612            matcher: AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet)),
1613            invert: true,
1614        };
1615        assert!(!matcher.matches(&local_ip));
1616        assert!(!matcher.matches(&remote_ip));
1617        assert!(matcher.matches(&I::get_other_remote_ip_address(1)));
1618
1619        let matcher = AddressMatcher {
1620            matcher: AddressMatcherType::Range(local_ip..=remote_ip),
1621            invert: false,
1622        };
1623        assert!(matcher.matches(&local_ip));
1624        assert!(matcher.matches(&remote_ip));
1625        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1626
1627        let matcher = AddressMatcher {
1628            matcher: AddressMatcherType::Range(local_ip..=remote_ip),
1629            invert: true,
1630        };
1631        assert!(!matcher.matches(&local_ip));
1632        assert!(!matcher.matches(&remote_ip));
1633        assert!(matcher.matches(&I::get_other_remote_ip_address(1)));
1634    }
1635
1636    #[test]
1637    fn agnostic_address_matcher() {
1638        let v4_addr = IpAddr::V4(Ipv4Addr::new([192, 0, 2, 1]));
1639        let v6_addr = IpAddr::V6(Ipv6Addr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1]));
1640
1641        let v4_subnet = Subnet::new(Ipv4Addr::new([192, 0, 2, 0]), 24).unwrap();
1642        let v6_subnet = Subnet::new(Ipv6Addr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 0]), 32).unwrap();
1643
1644        let v4_matcher = AddressMatcherEither::V4(AddressMatcher {
1645            matcher: AddressMatcherType::Subnet(SubnetMatcher(v4_subnet)),
1646            invert: false,
1647        });
1648        assert!(v4_matcher.matches(&v4_addr));
1649        assert!(!v4_matcher.matches(&v6_addr));
1650
1651        let v6_matcher = AddressMatcherEither::V6(AddressMatcher {
1652            matcher: AddressMatcherType::Subnet(SubnetMatcher(v6_subnet)),
1653            invert: false,
1654        });
1655        assert!(!v6_matcher.matches(&v4_addr));
1656        assert!(v6_matcher.matches(&v6_addr));
1657    }
1658
1659    #[test_case(IpSocketMatcher::Family(IpVersion::V4) => true; "v4 family matcher on v4 socket")]
1660    #[test_case(IpSocketMatcher::Family(IpVersion::V6) => false; "v6 family matcher on v4 socket")]
1661    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V4(AddressMatcher {
1662        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv4::TEST_ADDRS.subnet)),
1663        invert: false,
1664    })) => true; "src_addr match")]
1665    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V4(AddressMatcher {
1666        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv4Addr::new([0, 0, 0, 0]), 32).unwrap())),
1667        invert: false,
1668    })) => false; "src_addr no match")]
1669    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V4(AddressMatcher {
1670        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv4::TEST_ADDRS.subnet)),
1671        invert: false,
1672    })) => true; "dst_addr match")]
1673    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V4(AddressMatcher {
1674        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv4Addr::new([0, 0, 0, 0]), 32).unwrap())),
1675        invert: false,
1676    })) => false; "dst_addr no match")]
1677    fn ip_socket_matcher_test_v4(matcher: IpSocketMatcher<FakeDeviceClass>) -> bool {
1678        let socket = FakeIpSocket::<Ipv4, _> {
1679            src_ip: <Ipv4 as TestIpExt>::TEST_ADDRS.local_ip.get(),
1680            dst_ip: <Ipv4 as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1681            proto: FakeTcpSocket {
1682                src_port: 80,
1683                dst_port: 12345,
1684                state: TcpSocketState::Established,
1685            },
1686            cookie: 0,
1687            intf: None,
1688            marks: Marks::default(),
1689        };
1690        matcher.matches(&socket)
1691    }
1692
1693    #[test_case(IpSocketMatcher::Family(IpVersion::V4) => false; "v4 family matcher on v6 socket")]
1694    #[test_case(IpSocketMatcher::Family(IpVersion::V6) => true; "v6 family matcher on v6 socket")]
1695    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V6(AddressMatcher {
1696        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv6::TEST_ADDRS.subnet)),
1697        invert: false,
1698    })) => true; "src_addr match v6")]
1699    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V6(AddressMatcher {
1700        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv6Addr::new([0; 8]), 128).unwrap())),
1701        invert: false,
1702    })) => false; "src_addr no match v6")]
1703    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V6(AddressMatcher {
1704        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv6::TEST_ADDRS.subnet)),
1705        invert: false,
1706    })) => true; "dst_addr match v6")]
1707    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V6(AddressMatcher {
1708        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv6Addr::new([0; 8]), 128).unwrap())),
1709        invert: false,
1710    })) => false; "dst_addr no match v6")]
1711    fn ip_socket_matcher_test_v6(matcher: IpSocketMatcher<FakeDeviceClass>) -> bool {
1712        let socket = FakeIpSocket::<Ipv6, _> {
1713            src_ip: <Ipv6 as TestIpExt>::TEST_ADDRS.local_ip.get(),
1714            dst_ip: <Ipv6 as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1715            proto: FakeTcpSocket {
1716                src_port: 80,
1717                dst_port: 12345,
1718                state: TcpSocketState::Established,
1719            },
1720            cookie: 0,
1721            intf: None,
1722            marks: Marks::default(),
1723        };
1724        matcher.matches(&socket)
1725    }
1726}