1use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12 MatcherBindingsTypes, RngContext, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::BindingsPacketMatcher;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
/// Trait alias for the bindings-provided types required by the filtering
/// engine.
///
/// Requires instant types, timer types, and matcher types whose
/// `BindingsPacketMatcher` associated type implements
/// [`BindingsPacketMatcher`]. The `'static` bound rules out implementors that
/// borrow non-`'static` data.
pub trait FilterBindingsTypes:
    InstantBindingsTypes
    + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
    + TimerBindingsTypes
    + 'static
{
}
29
30impl<
31 BT: InstantBindingsTypes
32 + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
33 + TimerBindingsTypes
34 + 'static,
35> FilterBindingsTypes for BT
36{
37}
38
39pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
41impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
42
43pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
52 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
53{
54 type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
57
58 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
60 &mut self,
61 cb: F,
62 ) -> O {
63 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
64 }
65
66 fn with_filter_state_and_nat_ctx<
69 O,
70 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
71 >(
72 &mut self,
73 cb: F,
74 ) -> O;
75}
76
/// Address lookups exposed to NAT processing (this is the bound on
/// `FilterIpContext::NatCtx`).
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns a local address ID on `device_id` to use with `remote`, or
    /// `None` if the context provides none.
    ///
    /// NOTE(review): the selection policy, including how `remote == None` is
    /// treated, is left to implementors; confirm against the production impl.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns the address ID for `addr` on `device_id`, or `None` if the
    /// context does not resolve `addr` on that device.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
96
/// Simultaneous access to the IPv4 and IPv6 filtering state.
///
/// Device IDs for both IP versions must implement [`InterfaceProperties`]
/// over the bindings' `DeviceClass`.
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls `cb` with exclusive access to both the IPv4 and the IPv6 filter
    /// state at the same time.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
115
/// The verdict of [`SocketOpsFilter::on_egress`] for an outgoing packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// The packet may continue.
    Pass {
        /// Congestion indication reported by the filter.
        ///
        /// NOTE(review): how the egress path reacts to this flag is not
        /// visible here; confirm against the caller.
        congestion: bool,
    },

    /// The packet must be dropped.
    Drop {
        /// Congestion indication reported by the filter (same meaning as in
        /// `Pass`).
        congestion: bool,
    },
}
131
/// The verdict of [`SocketOpsFilter::on_ingress`] for an incoming packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// The packet is accepted.
    Accept,

    /// The packet must be dropped.
    Drop,
}
141
/// A filter consulted for traffic sent or received by sockets.
///
/// `D` is the device ID type passed to the hooks.
pub trait SocketOpsFilter<D> {
    /// Evaluates the outgoing `packet` on `device`.
    ///
    /// `cookie` is the socket's [`SocketCookie`], and `marks` are the
    /// [`Marks`] supplied by the caller.
    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Evaluates an incoming packet on `device`.
    ///
    /// Unlike [`SocketOpsFilter::on_egress`], the packet arrives as raw bytes
    /// in a [`FragmentedByteSlice`] tagged with its `ip_version`, not as a
    /// typed IP packet. NOTE(review): confirm whether `packet` starts at the
    /// IP header.
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
163
/// Bindings context that supplies a [`SocketOpsFilter`] for devices of type
/// `D`.
pub trait SocketOpsFilterBindingContext<D>: TxMetadataBindingsTypes {
    /// Returns the socket filter to apply.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
169
// Test/testutils only: the base crate's fake bindings context always hands out
// `crate::testutil::NoOpSocketOpsFilter` as its socket filter. The generic for
// the bindings state is named `S` so it does not shadow the imported `State`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, S, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, S, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    S: 'static,
    FrameMeta: 'static,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
184
185#[cfg(test)]
186pub(crate) mod testutil {
187 use alloc::sync::{Arc, Weak};
188 use alloc::vec::Vec;
189 use core::hash::{Hash, Hasher};
190 use core::ops::Deref;
191 use core::sync::atomic::AtomicUsize;
192 use core::time::Duration;
193
194 use derivative::Derivative;
195 use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
196 use netstack3_base::testutil::{
197 FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
198 FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
199 };
200 use netstack3_base::{
201 AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
202 IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
203 };
204 use netstack3_hashmap::HashMap;
205
206 use super::*;
207 use crate::conntrack;
208 use crate::logic::FilterTimerId;
209 use crate::logic::nat::NatConfig;
210 use crate::state::validation::ValidRoutines;
211 use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
212
213 pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
214
215 impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
216
    /// Owning handle for a fake device's primary address.
    ///
    /// Stored per device by [`FakeNatCtx`]; the [`FakeAddressId`]s it hands out
    /// clone this `Arc`.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// Fake strong address ID sharing ownership of an address subnet.
    ///
    /// `Eq`/`Hash` are derived, so they compare the `AddrSubnet` value, not
    /// the `Arc` pointer.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// Fake weak address ID; holding it does not keep the address alive.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
229
230 impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
231 fn eq(&self, other: &Self) -> bool {
232 let Self(lhs) = self;
233 let Self(rhs) = other;
234 Weak::ptr_eq(lhs, rhs)
235 }
236 }
237
238 impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
239
240 impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
241 fn hash<H: Hasher>(&self, state: &mut H) {
242 let Self(this) = self;
243 this.as_ptr().hash(state)
244 }
245 }
246
247 impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
248 type Strong = FakeAddressId<I>;
249
250 fn upgrade(&self) -> Option<Self::Strong> {
251 let Self(inner) = self;
252 inner.upgrade().map(FakeAddressId)
253 }
254
255 fn is_assigned(&self) -> bool {
256 let Self(inner) = self;
257 inner.strong_count() != 0
258 }
259 }
260
261 impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
262 fn record<Inspector: netstack3_base::Inspector>(
263 &self,
264 _name: &str,
265 _inspector: &mut Inspector,
266 ) {
267 unimplemented!()
268 }
269 }
270
271 impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
272 type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
273
274 fn deref(&self) -> &Self::Target {
275 let Self(inner) = self;
276 inner.deref()
277 }
278 }
279
280 impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
281 type Weak = FakeWeakAddressId<I>;
282
283 fn downgrade(&self) -> Self::Weak {
284 let Self(inner) = self;
285 FakeWeakAddressId(Arc::downgrade(inner))
286 }
287
288 fn addr(&self) -> IpDeviceAddr<I::Addr> {
289 let Self(inner) = self;
290
291 #[derive(GenericOverIp)]
292 #[generic_over_ip(I, Ip)]
293 struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
294 I::map_ip(
295 WrapIn(inner.addr()),
296 |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
297 |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
298 )
299 }
300
301 fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
302 let Self(inner) = self;
303 **inner
304 }
305 }
306
    /// Fake core context for filtering tests: the filter state for `I` plus a
    /// fake NAT context.
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// Fake [`NatContext`] that knows at most one (primary) address per
    /// device.
    #[derive(Derivative)]
    // `bound = ""` suppresses the `I: Default` bound the derive would
    // otherwise add.
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        // Primary address of each fake device.
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }
317
318 impl<I: TestIpExt> FakeCtx<I> {
319 pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
320 Self {
321 state: State {
322 installed_routines: ValidRoutines::default(),
323 uninstalled_routines: Vec::default(),
324 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
325 nat_installed: OneWayBoolean::default(),
326 },
327 nat: FakeNatCtx::default(),
328 }
329 }
330
331 pub fn with_ip_routines(
332 bindings_ctx: &mut FakeBindingsCtx<I>,
333 routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
334 ) -> Self {
335 let (installed_routines, uninstalled_routines) =
336 ValidRoutines::new(Routines { ip: routines, ..Default::default() })
337 .expect("invalid state");
338 Self {
339 state: State {
340 installed_routines,
341 uninstalled_routines,
342 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
343 nat_installed: OneWayBoolean::default(),
344 },
345 nat: FakeNatCtx::default(),
346 }
347 }
348
349 pub fn with_nat_routines_and_device_addrs(
350 bindings_ctx: &mut FakeBindingsCtx<I>,
351 routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
352 device_addrs: impl IntoIterator<
353 Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
354 >,
355 ) -> Self {
356 let (installed_routines, uninstalled_routines) =
357 ValidRoutines::new(Routines { nat: routines, ..Default::default() })
358 .expect("invalid state");
359 Self {
360 state: State {
361 installed_routines,
362 uninstalled_routines,
363 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
364 nat_installed: OneWayBoolean::TRUE,
365 },
366 nat: FakeNatCtx {
367 device_addrs: device_addrs
368 .into_iter()
369 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
370 .collect(),
371 },
372 }
373 }
374
375 pub fn conntrack(
376 &mut self,
377 ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
378 &self.state.conntrack
379 }
380 }
381
382 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
383 type DeviceId = FakeMatcherDeviceId;
384 type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
385 }
386
387 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
388 type AddressId = FakeAddressId<I>;
389 type WeakAddressId = FakeWeakAddressId<I>;
390 }
391
392 impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
393 type NatCtx<'a> = FakeNatCtx<I>;
394
395 fn with_filter_state_and_nat_ctx<
396 O,
397 F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
398 >(
399 &mut self,
400 cb: F,
401 ) -> O {
402 let Self { state, nat } = self;
403 cb(state, nat)
404 }
405 }
406
407 impl<I: TestIpExt> FakeNatCtx<I> {
408 pub fn new(
409 device_addrs: impl IntoIterator<
410 Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
411 >,
412 ) -> Self {
413 Self {
414 device_addrs: device_addrs
415 .into_iter()
416 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
417 .collect(),
418 }
419 }
420 }
421
422 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
423 type DeviceId = FakeMatcherDeviceId;
424 type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
425 }
426
427 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
428 type AddressId = FakeAddressId<I>;
429 type WeakAddressId = FakeWeakAddressId<I>;
430 }
431
432 impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
433 fn get_local_addr_for_remote(
434 &mut self,
435 device_id: &Self::DeviceId,
436 _remote: Option<SpecifiedAddr<I::Addr>>,
437 ) -> Option<Self::AddressId> {
438 let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
439 Some(FakeAddressId(primary.clone()))
440 }
441
442 fn get_address_id(
443 &mut self,
444 device_id: &Self::DeviceId,
445 addr: IpDeviceAddr<I::Addr>,
446 ) -> Option<Self::AddressId> {
447 let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
448 let id = FakeAddressId(id.clone());
449 if id.addr() == addr { Some(id) } else { None }
450 }
451 }
452
453 pub struct FakeBindingsCtx<I: Ip> {
454 pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
455 pub rng: FakeCryptoRng,
456 }
457
458 impl<I: Ip> FakeBindingsCtx<I> {
459 pub(crate) fn new() -> Self {
460 Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
461 }
462
463 pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
464 self.timer_ctx.instant.sleep(time_elapsed)
465 }
466 }
467
468 impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
469 type Instant = FakeInstant;
470 type AtomicInstant = FakeAtomicInstant;
471 }
472
473 #[derive(Debug)]
474 pub struct FakeBindingsPacketMatcher {
475 num_calls: AtomicUsize,
476 result: bool,
477 }
478
479 impl FakeBindingsPacketMatcher {
480 pub fn new(result: bool) -> Arc<Self> {
481 Arc::new(Self { num_calls: AtomicUsize::new(0), result })
482 }
483
484 pub fn num_calls(&self) -> usize {
485 self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
486 }
487 }
488
489 impl BindingsPacketMatcher for FakeBindingsPacketMatcher {
490 fn matches<I: FilterIpExt, P: IpPacket<I>>(&self, _packet: &P) -> bool {
491 let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
492 self.result
493 }
494 }
495
496 impl InspectableValue for FakeBindingsPacketMatcher {
497 fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
498 unimplemented!()
499 }
500 }
501
502 impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
503 type DeviceClass = FakeDeviceClass;
504 type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
505 }
506
507 impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
508 fn now(&self) -> Self::Instant {
509 self.timer_ctx.now()
510 }
511 }
512
513 impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
514 type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
515 type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
516 type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
517 }
518
519 impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
520 fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
521 self.timer_ctx.new_timer(id)
522 }
523
524 fn schedule_timer_instant(
525 &mut self,
526 time: Self::Instant,
527 timer: &mut Self::Timer,
528 ) -> Option<Self::Instant> {
529 self.timer_ctx.schedule_timer_instant(time, timer)
530 }
531
532 fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
533 self.timer_ctx.cancel_timer(timer)
534 }
535
536 fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
537 self.timer_ctx.scheduled_instant(timer)
538 }
539
540 fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
541 self.timer_ctx.unique_timer_id(timer)
542 }
543 }
544
545 impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
546 fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
547 &self,
548 f: F,
549 ) -> O {
550 f(&self.timer_ctx)
551 }
552
553 fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
554 &mut self,
555 f: F,
556 ) -> O {
557 f(&mut self.timer_ctx)
558 }
559 }
560
561 impl<I: Ip> RngContext for FakeBindingsCtx<I> {
562 type Rng<'a>
563 = FakeCryptoRng
564 where
565 Self: 'a;
566
567 fn rng(&mut self) -> Self::Rng<'_> {
568 self.rng.clone()
569 }
570 }
571}