netstack3_ip/routing/
rules.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! IP routing rules.
6
7use alloc::vec::Vec;
8use core::fmt::Debug;
9use core::ops::Deref as _;
10
11use net_types::ip::Ip;
12use netstack3_base::{
13    DeviceNameMatcher, DeviceWithName, Mark, MarkDomain, MarkStorage, Marks, Matcher, SubnetMatcher,
14};
15
16use crate::internal::routing::PacketOrigin;
17use crate::RoutingTableId;
18
19/// Table that contains routing rules.
20pub struct RulesTable<I: Ip, D> {
21    /// Rules of the table.
22    rules: Vec<Rule<I, D>>,
23}
24
25impl<I: Ip, D> RulesTable<I, D> {
26    pub(crate) fn new(main_table_id: RoutingTableId<I, D>) -> Self {
27        // TODO(https://fxbug.dev/355059790): If bindings is installing the main table, we should
28        // also let the bindings install this default rule.
29        Self {
30            rules: alloc::vec![Rule {
31                matcher: RuleMatcher::match_all_packets(),
32                action: RuleAction::Lookup(main_table_id)
33            }],
34        }
35    }
36
37    pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ Rule<I, D>> {
38        self.rules.iter()
39    }
40
41    /// Gets the mutable reference to the rules vector.
42    #[cfg(any(test, feature = "testutils"))]
43    pub fn rules_mut(&mut self) -> &mut Vec<Rule<I, D>> {
44        &mut self.rules
45    }
46
47    /// Replaces the rules inside this table.
48    pub fn replace(&mut self, new_rules: Vec<Rule<I, D>>) {
49        self.rules = new_rules;
50    }
51}
52
53/// A routing rule.
54pub struct Rule<I: Ip, D> {
55    /// The matcher of the rule.
56    pub matcher: RuleMatcher<I>,
57    /// The action of the rule.
58    pub action: RuleAction<RoutingTableId<I, D>>,
59}
60
/// The action part of a [`Rule`], generic over how the routing table to consult is identified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RuleAction<Lookup> {
    /// Resolve the packet as unreachable.
    Unreachable,
    /// Look the packet up in the given routing table.
    Lookup(Lookup),
}
69
70/// Matches with [`PacketOrigin`].
71///
72/// Note that this matcher doesn't specify the source address/bound address like [`PacketOrigin`]
73/// because the user can specify a source address matcher without specifying the direction of the
74/// traffic.
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub enum TrafficOriginMatcher {
77    /// This only matches packets that are generated locally; the optional interface matcher
78    /// can be used to match what device is bound to by `SO_BINDTODEVICE`.
79    Local {
80        /// The matcher for the bound device.
81        bound_device_matcher: Option<DeviceNameMatcher>,
82    },
83    /// This only matches non-local packets. The packets must be received from the network.
84    NonLocal,
85}
86
87impl<'a, I: Ip, D: DeviceWithName> Matcher<PacketOrigin<I, &'a D>> for SubnetMatcher<I::Addr> {
88    fn matches(&self, actual: &PacketOrigin<I, &'a D>) -> bool {
89        match actual {
90            PacketOrigin::Local { bound_address, bound_device: _ } => {
91                self.required_matches(bound_address.as_deref())
92            }
93            PacketOrigin::NonLocal { source_address, incoming_device: _ } => {
94                self.matches(source_address.deref())
95            }
96        }
97    }
98}
99
100impl<'a, I: Ip, D: DeviceWithName> Matcher<PacketOrigin<I, &'a D>> for TrafficOriginMatcher {
101    fn matches(&self, actual: &PacketOrigin<I, &'a D>) -> bool {
102        match (self, actual) {
103            (
104                TrafficOriginMatcher::Local { bound_device_matcher },
105                PacketOrigin::Local { bound_address: _, bound_device },
106            ) => bound_device_matcher.required_matches(*bound_device),
107            (
108                TrafficOriginMatcher::NonLocal,
109                PacketOrigin::NonLocal { source_address: _, incoming_device: _ },
110            ) => true,
111            (TrafficOriginMatcher::Local { .. }, PacketOrigin::NonLocal { .. })
112            | (TrafficOriginMatcher::NonLocal, PacketOrigin::Local { .. }) => false,
113        }
114    }
115}
116
/// A matcher for a single mark value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarkMatcher {
    /// Matches only when no mark is present.
    Unmarked,
    /// Matches when a mark is present and, after applying `mask`, lies within
    /// `start..=end`.
    Marked {
        /// Bitmask applied to the mark before the range check.
        mask: u32,
        /// Lower bound of the range, inclusive.
        start: u32,
        /// Upper bound of the range, inclusive.
        end: u32,
    },
}
132
133impl Matcher<Mark> for MarkMatcher {
134    fn matches(&self, Mark(actual): &Mark) -> bool {
135        match self {
136            MarkMatcher::Unmarked => actual.is_none(),
137            MarkMatcher::Marked { mask, start, end } => {
138                actual.is_some_and(|actual| (*start..=*end).contains(&(actual & *mask)))
139            }
140        }
141    }
142}
143
144/// The 2 mark matchers a rule can specify. All non-none markers must match.
145#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
146pub struct MarkMatchers(MarkStorage<Option<MarkMatcher>>);
147
148impl MarkMatchers {
149    /// Creates [`MarkMatcher`]s from an iterator of `(MarkDomain, MarkMatcher)`.
150    ///
151    /// An unspecified domain will not have a matcher.
152    ///
153    /// # Panics
154    ///
155    /// Panics if the same domain is specified more than once.
156    pub fn new(matchers: impl IntoIterator<Item = (MarkDomain, MarkMatcher)>) -> Self {
157        MarkMatchers(MarkStorage::new(matchers))
158    }
159
160    /// Returns an iterator over the mark matchers of all domains.
161    pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &Option<MarkMatcher>)> {
162        let Self(storage) = self;
163        storage.iter()
164    }
165}
166
167impl Matcher<Marks> for MarkMatchers {
168    fn matches(&self, actual: &Marks) -> bool {
169        let Self(matchers) = self;
170        matchers.zip_with(actual).all(|(_domain, matcher, actual)| matcher.matches(actual))
171    }
172}
173
174/// Contains traffic matchers for a given rule.
175///
176/// `None` fields match all packets.
177#[derive(Debug, Clone, PartialEq, Eq)]
178pub struct RuleMatcher<I: Ip> {
179    /// Matches on [`PacketOrigin`]'s bound address for a locally generated packet or the source
180    /// address of an incoming packet.
181    ///
182    /// Matches whether the source address of the packet is from the subnet. If the matcher is
183    /// specified but the source address is not specified, it resolves to not a match.
184    pub source_address_matcher: Option<SubnetMatcher<I::Addr>>,
185    /// Matches on [`PacketOrigin`]'s bound device for a locally generated packets or the receiving
186    /// device of an incoming packet.
187    pub traffic_origin_matcher: Option<TrafficOriginMatcher>,
188    /// Matches on [`RuleInput`]'s marks.
189    pub mark_matchers: MarkMatchers,
190}
191
192impl<I: Ip> RuleMatcher<I> {
193    /// Creates a rule matcher that matches all packets.
194    pub fn match_all_packets() -> Self {
195        RuleMatcher {
196            source_address_matcher: None,
197            traffic_origin_matcher: None,
198            mark_matchers: MarkMatchers::default(),
199        }
200    }
201}
202
203/// Packet properties used as input for the rules engine.
204pub struct RuleInput<'a, I: Ip, D> {
205    pub(crate) packet_origin: PacketOrigin<I, &'a D>,
206    pub(crate) marks: &'a Marks,
207}
208
209impl<'a, I: Ip, D: DeviceWithName> Matcher<RuleInput<'a, I, D>> for RuleMatcher<I> {
210    fn matches(&self, actual: &RuleInput<'a, I, D>) -> bool {
211        let Self { source_address_matcher, traffic_origin_matcher, mark_matchers } = self;
212        let RuleInput { packet_origin, marks } = actual;
213        source_address_matcher.matches(packet_origin)
214            && traffic_origin_matcher.matches(packet_origin)
215            && mark_matchers.matches(marks)
216    }
217}
218
/// Unit tests for the routing rule matchers defined in this module.
#[cfg(test)]
mod test {
    use ip_test_macro::ip_test;
    use net_types::ip::Subnet;
    use net_types::SpecifiedAddr;
    use netstack3_base::testutil::{FakeDeviceId, MultipleDevicesId, TestIpExt};
    use test_case::test_case;

    use super::*;

    #[ip_test(I)]
    #[test_case(None, None => true)]
    #[test_case(None, Some(MultipleDevicesId::A) => true)]
    #[test_case(Some("A"), None => false)]
    #[test_case(Some("A"), Some(MultipleDevicesId::A) => true)]
    #[test_case(Some("A"), Some(MultipleDevicesId::B) => false)]
    fn rule_matcher_matches_device_name<I: TestIpExt>(
        device_name: Option<&str>,
        bound_device: Option<MultipleDevicesId>,
    ) -> bool {
        let matcher = RuleMatcher::<I> {
            traffic_origin_matcher: Some(TrafficOriginMatcher::Local {
                bound_device_matcher: device_name.map(|name| DeviceNameMatcher(name.into())),
            }),
            ..RuleMatcher::match_all_packets()
        };
        let input = RuleInput {
            packet_origin: PacketOrigin::Local {
                bound_address: None,
                bound_device: bound_device.as_ref(),
            },
            marks: &Default::default(),
        };
        matcher.matches(&input)
    }

    #[ip_test(I)]
    #[test_case(None, None => true)]
    #[test_case(None, Some(I::LOOPBACK_ADDRESS) => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        None => false)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        Some(<I as TestIpExt>::TEST_ADDRS.local_ip) => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        Some(<I as TestIpExt>::get_other_remote_ip_address(1)) => false)]
    fn rule_matcher_matches_local_addr<I: TestIpExt>(
        source_address_subnet: Option<Subnet<I::Addr>>,
        bound_address: Option<SpecifiedAddr<I::Addr>>,
    ) -> bool {
        let matcher = RuleMatcher::<I> {
            source_address_matcher: source_address_subnet.map(SubnetMatcher),
            ..RuleMatcher::match_all_packets()
        };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> {
            packet_origin: PacketOrigin::Local { bound_address, bound_device: None },
            marks: &marks,
        };
        matcher.matches(&input)
    }

    // The source address matcher must also apply to the source address of packets received
    // from the network, not only to the bound address of locally generated packets.
    #[ip_test(I)]
    #[test_case(None, <I as TestIpExt>::TEST_ADDRS.local_ip => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        <I as TestIpExt>::TEST_ADDRS.local_ip => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        <I as TestIpExt>::get_other_remote_ip_address(1) => false)]
    fn rule_matcher_matches_non_local_source_addr<I: TestIpExt>(
        source_address_subnet: Option<Subnet<I::Addr>>,
        source_address: SpecifiedAddr<I::Addr>,
    ) -> bool {
        let matcher = RuleMatcher::<I> {
            source_address_matcher: source_address_subnet.map(SubnetMatcher),
            ..RuleMatcher::match_all_packets()
        };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> {
            packet_origin: PacketOrigin::NonLocal {
                source_address,
                incoming_device: &FakeDeviceId,
            },
            marks: &marks,
        };
        matcher.matches(&input)
    }

    #[ip_test(I)]
    #[test_case(None, PacketOrigin::Local {
        bound_address: None,
        bound_device: None
    } => true)]
    #[test_case(None, PacketOrigin::NonLocal {
        source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
        incoming_device: &FakeDeviceId
    } => true)]
    #[test_case(Some(TrafficOriginMatcher::Local {
        bound_device_matcher: None
    }), PacketOrigin::Local {
        bound_address: None,
        bound_device: None
    } => true)]
    #[test_case(Some(TrafficOriginMatcher::NonLocal),
        PacketOrigin::NonLocal {
            source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
            incoming_device: &FakeDeviceId
        } => true)]
    #[test_case(Some(TrafficOriginMatcher::Local { bound_device_matcher: None }),
        PacketOrigin::NonLocal {
            source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
            incoming_device: &FakeDeviceId
        }  => false)]
    #[test_case(Some(TrafficOriginMatcher::NonLocal),
        PacketOrigin::Local {
            bound_address: None,
            bound_device: None
        } => false)]
    fn rule_matcher_matches_locally_generated<I: TestIpExt>(
        traffic_origin_matcher: Option<TrafficOriginMatcher>,
        packet_origin: PacketOrigin<I, &'static FakeDeviceId>,
    ) -> bool {
        let matcher =
            RuleMatcher::<I> { traffic_origin_matcher, ..RuleMatcher::match_all_packets() };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> { packet_origin, marks: &marks };
        matcher.matches(&input)
    }

    #[ip_test(I)]
    #[test_case::test_matrix(
            [
                None,
                Some(<I as TestIpExt>::TEST_ADDRS.local_ip),
                Some(<I as TestIpExt>::get_other_remote_ip_address(1))
            ],
            [
                None,
                Some(&MultipleDevicesId::A),
                Some(&MultipleDevicesId::B),
                Some(&MultipleDevicesId::C),
            ],
            [true, false]
        )]
    fn rule_matcher_matches_multiple_conditions<I: TestIpExt>(
        ip: Option<SpecifiedAddr<I::Addr>>,
        device: Option<&'static MultipleDevicesId>,
        locally_generated: bool,
    ) {
        let matcher = RuleMatcher::<I> {
            source_address_matcher: Some(SubnetMatcher(I::TEST_ADDRS.subnet)),
            traffic_origin_matcher: Some(TrafficOriginMatcher::Local {
                bound_device_matcher: Some(DeviceNameMatcher("A".into())),
            }),
            ..RuleMatcher::match_all_packets()
        };

        let packet_origin = if locally_generated {
            PacketOrigin::Local { bound_address: ip, bound_device: device }
        } else {
            let (Some(source_address), Some(incoming_device)) = (ip, device) else {
                return;
            };
            PacketOrigin::NonLocal { source_address, incoming_device }
        };

        let input = RuleInput { packet_origin, marks: &Default::default() };

        if ip == Some(I::TEST_ADDRS.local_ip)
            && (device == Some(&MultipleDevicesId::A))
            && locally_generated
        {
            assert!(matcher.matches(&input))
        } else {
            assert!(!matcher.matches(&input))
        }
    }

    #[test_case(MarkMatcher::Unmarked, Mark(None) => true)]
    #[test_case(MarkMatcher::Unmarked, Mark(Some(0)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(None) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(0)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(1)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(2)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(3)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 0xff,
        start: 1,
        end: 10,
    }, Mark(Some(1)) => true; "range_start_is_inclusive")]
    #[test_case(MarkMatcher::Marked {
        mask: 0xff,
        start: 1,
        end: 10,
    }, Mark(Some(10)) => true; "range_end_is_inclusive")]
    #[test_case(MarkMatcher::Marked {
        mask: 0xff,
        start: 1,
        end: 10,
    }, Mark(Some(11)) => false; "mark_above_range")]
    #[test_case(MarkMatcher::Marked {
        mask: 0xff,
        start: 1,
        end: 10,
    }, Mark(Some(0x105)) => true; "masked_mark_in_range")]
    #[test_case(MarkMatcher::Marked {
        mask: 0xff,
        start: 1,
        end: 10,
    }, Mark(Some(0x100)) => false; "masked_mark_below_range")]
    fn mark_matcher(matcher: MarkMatcher, mark: Mark) -> bool {
        matcher.matches(&mark)
    }

    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([]) => true
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([(MarkDomain::Mark1, 1)]) => false
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([(MarkDomain::Mark2, 1)]) => false
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([
            (MarkDomain::Mark1, 1),
            (MarkDomain::Mark2, 1),
        ]) => false
    )]
    #[test_case(
        MarkMatchers::default(),
        Marks::new([
            (MarkDomain::Mark1, 1),
            (MarkDomain::Mark2, 1),
        ]) => true;
        "no_matchers_match_any_marks"
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: u32::MAX, start: 1, end: 10 })]
        ),
        Marks::new([
            (MarkDomain::Mark1, 5),
            (MarkDomain::Mark2, 100),
        ]) => true;
        "unconstrained_domain_is_ignored"
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: u32::MAX, start: 1, end: 10 })]
        ),
        Marks::new([(MarkDomain::Mark2, 5)]) => false;
        "marked_matcher_rejects_unmarked_domain"
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: u32::MAX, start: 1, end: 10 })]
        ),
        Marks::new([(MarkDomain::Mark1, 11)]) => false;
        "marked_matcher_rejects_out_of_range_mark"
    )]
    fn mark_matchers(matchers: MarkMatchers, marks: Marks) -> bool {
        matchers.matches(&marks)
    }
}