1use alloc::vec::Vec;
8use core::fmt::Debug;
9use core::ops::Deref as _;
10
11use net_types::ip::Ip;
12use netstack3_base::{
13 DeviceNameMatcher, DeviceWithName, Mark, MarkDomain, MarkStorage, Marks, Matcher, SubnetMatcher,
14};
15
16use crate::internal::routing::PacketOrigin;
17use crate::RoutingTableId;
18
/// An ordered collection of routing rules.
///
/// The rules are kept in a `Vec` and handed out in that order by `iter`. A table
/// created with `new` starts with a single catch-all rule that looks packets up in
/// the main routing table.
///
/// NOTE(review): the routing lookup that consumes these rules lives outside this
/// file; presumably rules are evaluated front to back with the first applicable
/// action winning — confirm against the caller.
pub struct RulesTable<I: Ip, D> {
    /// The rules, stored in evaluation/iteration order.
    rules: Vec<Rule<I, D>>,
}
24
25impl<I: Ip, D> RulesTable<I, D> {
26 pub(crate) fn new(main_table_id: RoutingTableId<I, D>) -> Self {
27 Self {
30 rules: alloc::vec![Rule {
31 matcher: RuleMatcher::match_all_packets(),
32 action: RuleAction::Lookup(main_table_id)
33 }],
34 }
35 }
36
37 pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ Rule<I, D>> {
38 self.rules.iter()
39 }
40
41 #[cfg(any(test, feature = "testutils"))]
43 pub fn rules_mut(&mut self) -> &mut Vec<Rule<I, D>> {
44 &mut self.rules
45 }
46
47 pub fn replace(&mut self, new_rules: Vec<Rule<I, D>>) {
49 self.rules = new_rules;
50 }
51}
52
/// A single routing rule: packets accepted by `matcher` are handled by `action`.
pub struct Rule<I: Ip, D> {
    /// Decides which packets this rule applies to.
    pub matcher: RuleMatcher<I>,
    /// What happens to a packet accepted by `matcher`.
    pub action: RuleAction<RoutingTableId<I, D>>,
}
60
/// The action a rule applies to the packets it matches.
///
/// Generic over the type naming the routing table to consult for `Lookup`;
/// `Rule` instantiates it with `RoutingTableId`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RuleAction<Lookup> {
    /// Treat the packet's destination as unreachable.
    ///
    /// NOTE(review): how this is surfaced (error kind, ICMP, etc.) is decided by
    /// the routing code that consumes the rule — confirm there.
    Unreachable,
    /// Look the packet up in the given routing table.
    Lookup(Lookup),
}
69
/// Matches on where a packet originated.
///
/// `Local` only matches packets whose origin is `PacketOrigin::Local`;
/// `NonLocal` only matches packets whose origin is `PacketOrigin::NonLocal`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrafficOriginMatcher {
    /// Matches locally generated traffic.
    Local {
        /// Optional condition on the device the packet is bound to.
        ///
        /// When set, the packet must be bound to a device whose name matches; a
        /// packet with no bound device then does not match. When `None`, any
        /// locally generated packet matches.
        bound_device_matcher: Option<DeviceNameMatcher>,
    },
    /// Matches traffic that did not originate locally (presumably received or
    /// forwarded traffic — confirm against `PacketOrigin`).
    NonLocal,
}
86
87impl<'a, I: Ip, D: DeviceWithName> Matcher<PacketOrigin<I, &'a D>> for SubnetMatcher<I::Addr> {
88 fn matches(&self, actual: &PacketOrigin<I, &'a D>) -> bool {
89 match actual {
90 PacketOrigin::Local { bound_address, bound_device: _ } => {
91 self.required_matches(bound_address.as_deref())
92 }
93 PacketOrigin::NonLocal { source_address, incoming_device: _ } => {
94 self.matches(source_address.deref())
95 }
96 }
97 }
98}
99
100impl<'a, I: Ip, D: DeviceWithName> Matcher<PacketOrigin<I, &'a D>> for TrafficOriginMatcher {
101 fn matches(&self, actual: &PacketOrigin<I, &'a D>) -> bool {
102 match (self, actual) {
103 (
104 TrafficOriginMatcher::Local { bound_device_matcher },
105 PacketOrigin::Local { bound_address: _, bound_device },
106 ) => bound_device_matcher.required_matches(*bound_device),
107 (
108 TrafficOriginMatcher::NonLocal,
109 PacketOrigin::NonLocal { source_address: _, incoming_device: _ },
110 ) => true,
111 (TrafficOriginMatcher::Local { .. }, PacketOrigin::NonLocal { .. })
112 | (TrafficOriginMatcher::NonLocal, PacketOrigin::Local { .. }) => false,
113 }
114 }
115}
116
/// Matches on the value of a single packet mark.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarkMatcher {
    /// Matches only when the mark is not set.
    Unmarked,
    /// Matches when the mark is set and `mark & mask` lies in `start..=end`.
    ///
    /// Both bounds are inclusive; if `start > end` the range is empty and nothing
    /// matches.
    Marked {
        /// Bits of the mark that participate in the comparison.
        mask: u32,
        /// Inclusive lower bound for the masked mark value.
        start: u32,
        /// Inclusive upper bound for the masked mark value.
        end: u32,
    },
}
132
133impl Matcher<Mark> for MarkMatcher {
134 fn matches(&self, Mark(actual): &Mark) -> bool {
135 match self {
136 MarkMatcher::Unmarked => actual.is_none(),
137 MarkMatcher::Marked { mask, start, end } => {
138 actual.is_some_and(|actual| (*start..=*end).contains(&(actual & *mask)))
139 }
140 }
141 }
142}
143
/// An optional [`MarkMatcher`] for each mark domain.
///
/// The `Default` value is what `RuleMatcher::match_all_packets` uses, so a
/// domain without a matcher is expected to accept any mark.
/// NOTE(review): that relies on the `Option<M>` matcher impl in
/// `netstack3_base` — confirm.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct MarkMatchers(MarkStorage<Option<MarkMatcher>>);
147
148impl MarkMatchers {
149 pub fn new(matchers: impl IntoIterator<Item = (MarkDomain, MarkMatcher)>) -> Self {
157 MarkMatchers(MarkStorage::new(matchers))
158 }
159
160 pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &Option<MarkMatcher>)> {
162 let Self(storage) = self;
163 storage.iter()
164 }
165}
166
167impl Matcher<Marks> for MarkMatchers {
168 fn matches(&self, actual: &Marks) -> bool {
169 let Self(matchers) = self;
170 matchers.zip_with(actual).all(|(_domain, matcher, actual)| matcher.matches(actual))
171 }
172}
173
/// The conditions a packet must satisfy for a rule to apply.
///
/// Every configured condition must accept the packet. A condition that is `None`
/// imposes no constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuleMatcher<I: Ip> {
    /// Optional subnet condition on the packet's source address.
    ///
    /// For locally generated packets this is the bound address, and an unbound
    /// packet does not match. For non-local packets it is the source address.
    pub source_address_matcher: Option<SubnetMatcher<I::Addr>>,
    /// Optional condition on whether the packet is locally generated, and for
    /// local packets, on the device it is bound to.
    pub traffic_origin_matcher: Option<TrafficOriginMatcher>,
    /// Per-domain conditions on the packet's marks.
    pub mark_matchers: MarkMatchers,
}
191
192impl<I: Ip> RuleMatcher<I> {
193 pub fn match_all_packets() -> Self {
195 RuleMatcher {
196 source_address_matcher: None,
197 traffic_origin_matcher: None,
198 mark_matchers: MarkMatchers::default(),
199 }
200 }
201}
202
/// The properties of a packet that a [`RuleMatcher`] inspects.
pub struct RuleInput<'a, I: Ip, D> {
    /// Where the packet came from, together with its address and device.
    pub(crate) packet_origin: PacketOrigin<I, &'a D>,
    /// The marks carried by the packet.
    pub(crate) marks: &'a Marks,
}
208
209impl<'a, I: Ip, D: DeviceWithName> Matcher<RuleInput<'a, I, D>> for RuleMatcher<I> {
210 fn matches(&self, actual: &RuleInput<'a, I, D>) -> bool {
211 let Self { source_address_matcher, traffic_origin_matcher, mark_matchers } = self;
212 let RuleInput { packet_origin, marks } = actual;
213 source_address_matcher.matches(packet_origin)
214 && traffic_origin_matcher.matches(packet_origin)
215 && mark_matchers.matches(marks)
216 }
217}
218
#[cfg(test)]
mod test {
    use ip_test_macro::ip_test;
    use net_types::ip::Subnet;
    use net_types::SpecifiedAddr;
    use netstack3_base::testutil::{FakeDeviceId, MultipleDevicesId, TestIpExt};
    use test_case::test_case;

    use super::*;

    // A `Local` origin matcher with a device-name condition needs a bound device
    // with that name. Without the condition, any bound device (or none) is accepted.
    #[ip_test(I)]
    #[test_case(None, None => true)]
    #[test_case(None, Some(MultipleDevicesId::A) => true)]
    #[test_case(Some("A"), None => false)]
    #[test_case(Some("A"), Some(MultipleDevicesId::A) => true)]
    #[test_case(Some("A"), Some(MultipleDevicesId::B) => false)]
    fn rule_matcher_matches_device_name<I: TestIpExt>(
        device_name: Option<&str>,
        bound_device: Option<MultipleDevicesId>,
    ) -> bool {
        let matcher = RuleMatcher::<I> {
            traffic_origin_matcher: Some(TrafficOriginMatcher::Local {
                bound_device_matcher: device_name.map(|name| DeviceNameMatcher(name.into())),
            }),
            ..RuleMatcher::match_all_packets()
        };
        let input = RuleInput {
            packet_origin: PacketOrigin::Local {
                bound_address: None,
                bound_device: bound_device.as_ref(),
            },
            marks: &Default::default(),
        };
        matcher.matches(&input)
    }

    // For locally generated packets, the source-address matcher checks the
    // bound address. An unbound packet never satisfies a subnet condition.
    #[ip_test(I)]
    #[test_case(None, None => true)]
    #[test_case(None, Some(I::LOOPBACK_ADDRESS) => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        None => false)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        Some(<I as TestIpExt>::TEST_ADDRS.local_ip) => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        Some(<I as TestIpExt>::get_other_remote_ip_address(1)) => false)]
    fn rule_matcher_matches_local_addr<I: TestIpExt>(
        source_address_subnet: Option<Subnet<I::Addr>>,
        bound_address: Option<SpecifiedAddr<I::Addr>>,
    ) -> bool {
        let matcher = RuleMatcher::<I> {
            source_address_matcher: source_address_subnet.map(SubnetMatcher),
            ..RuleMatcher::match_all_packets()
        };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> {
            packet_origin: PacketOrigin::Local { bound_address, bound_device: None },
            marks: &marks,
        };
        matcher.matches(&input)
    }

    // The traffic-origin matcher separates local from non-local packets. When no
    // matcher is configured, both kinds are accepted.
    #[ip_test(I)]
    #[test_case(None, PacketOrigin::Local {
        bound_address: None,
        bound_device: None
    } => true)]
    #[test_case(None, PacketOrigin::NonLocal {
        source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
        incoming_device: &FakeDeviceId
    } => true)]
    #[test_case(Some(TrafficOriginMatcher::Local {
        bound_device_matcher: None
    }), PacketOrigin::Local {
        bound_address: None,
        bound_device: None
    } => true)]
    #[test_case(Some(TrafficOriginMatcher::NonLocal),
        PacketOrigin::NonLocal {
            source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
            incoming_device: &FakeDeviceId
        } => true)]
    #[test_case(Some(TrafficOriginMatcher::Local { bound_device_matcher: None }),
        PacketOrigin::NonLocal {
            source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
            incoming_device: &FakeDeviceId
        } => false)]
    #[test_case(Some(TrafficOriginMatcher::NonLocal),
        PacketOrigin::Local {
            bound_address: None,
            bound_device: None
        } => false)]
    fn rule_matcher_matches_locally_generated<I: TestIpExt>(
        traffic_origin_matcher: Option<TrafficOriginMatcher>,
        packet_origin: PacketOrigin<I, &'static FakeDeviceId>,
    ) -> bool {
        let matcher =
            RuleMatcher::<I> { traffic_origin_matcher, ..RuleMatcher::match_all_packets() };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> { packet_origin, marks: &marks };
        matcher.matches(&input)
    }

    // All configured conditions must hold at the same time. Only a locally
    // generated packet bound to the test local IP and to device "A" matches.
    // Non-local combinations with a missing address or device cannot be built
    // and are skipped.
    #[ip_test(I)]
    #[test_case::test_matrix(
        [
            None,
            Some(<I as TestIpExt>::TEST_ADDRS.local_ip),
            Some(<I as TestIpExt>::get_other_remote_ip_address(1))
        ],
        [
            None,
            Some(&MultipleDevicesId::A),
            Some(&MultipleDevicesId::B),
            Some(&MultipleDevicesId::C),
        ],
        [true, false]
    )]
    fn rule_matcher_matches_multiple_conditions<I: TestIpExt>(
        ip: Option<SpecifiedAddr<I::Addr>>,
        device: Option<&'static MultipleDevicesId>,
        locally_generated: bool,
    ) {
        let matcher = RuleMatcher::<I> {
            source_address_matcher: Some(SubnetMatcher(I::TEST_ADDRS.subnet)),
            traffic_origin_matcher: Some(TrafficOriginMatcher::Local {
                bound_device_matcher: Some(DeviceNameMatcher("A".into())),
            }),
            ..RuleMatcher::match_all_packets()
        };

        let packet_origin = if locally_generated {
            PacketOrigin::Local { bound_address: ip, bound_device: device }
        } else {
            let (Some(source_address), Some(incoming_device)) = (ip, device) else {
                return;
            };
            PacketOrigin::NonLocal { source_address, incoming_device }
        };

        let input = RuleInput { packet_origin, marks: &Default::default() };

        if ip == Some(I::TEST_ADDRS.local_ip)
            && (device == Some(&MultipleDevicesId::A))
            && locally_generated
        {
            assert!(matcher.matches(&input))
        } else {
            assert!(!matcher.matches(&input))
        }
    }

    // `Unmarked` accepts only an absent mark. `Marked` compares `mark & mask`
    // against the inclusive range `start..=end`. The 10..=20 cases pin both
    // inclusive bounds, and the 0x10F case shows that bits outside `mask` are
    // ignored.
    #[test_case(MarkMatcher::Unmarked, Mark(None) => true)]
    #[test_case(MarkMatcher::Unmarked, Mark(Some(0)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(None) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(0)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(1)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(2)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(3)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 0xFF,
        start: 10,
        end: 20,
    }, Mark(Some(10)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 0xFF,
        start: 10,
        end: 20,
    }, Mark(Some(20)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 0xFF,
        start: 10,
        end: 20,
    }, Mark(Some(9)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 0xFF,
        start: 10,
        end: 20,
    }, Mark(Some(21)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 0xFF,
        start: 10,
        end: 20,
    }, Mark(Some(0x10F)) => true)]
    fn mark_matcher(matcher: MarkMatcher, mark: Mark) -> bool {
        matcher.matches(&mark)
    }

    // Every mark domain must satisfy its matcher. A domain with no matcher
    // (including the all-default `MarkMatchers`) accepts any mark.
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([]) => true
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([(MarkDomain::Mark1, 1)]) => false
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([(MarkDomain::Mark2, 1)]) => false
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([
            (MarkDomain::Mark1, 1),
            (MarkDomain::Mark2, 1),
        ]) => false
    )]
    #[test_case(
        MarkMatchers::new([(MarkDomain::Mark1, MarkMatcher::Unmarked)]),
        Marks::new([(MarkDomain::Mark2, 1)]) => true
    )]
    #[test_case(
        MarkMatchers::default(),
        Marks::new([
            (MarkDomain::Mark1, 1),
            (MarkDomain::Mark2, 1),
        ]) => true
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: 0xFF, start: 10, end: 20 })]
        ),
        Marks::new([(MarkDomain::Mark1, 0x10F)]) => true
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: 0xFF, start: 10, end: 20 })]
        ),
        Marks::new([]) => false
    )]
    fn mark_matchers(matchers: MarkMatchers, marks: Marks) -> bool {
        matchers.matches(&marks)
    }
}