1use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9
10pub use netstack3_base::Marks;
11pub use netstack3_base::socket::{EitherIpProto, SocketInfo};
12use netstack3_base::{
13 InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext,
14 MatcherBindingsTypes, RngContext, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
15};
16use packet::FragmentedByteSlice;
17use packet_formats::ip::IpExt;
18
19use crate::FilterIpExt;
20use crate::matchers::BindingsPacketMatcher;
21use crate::packets::FilterIpPacket;
22use crate::state::State;
23
24pub trait FilterBindingsTypes:
26 InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static
27{
28}
29
30impl<BT: InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static>
31 FilterBindingsTypes for BT
32{
33}
34
35pub trait FilterBindingsContext<D>:
37 TimerContext + RngContext + FilterBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher<D>>
38{
39}
40impl<D, BC> FilterBindingsContext<D> for BC
41where
42 BC: TimerContext + RngContext + FilterBindingsTypes,
43 BC::BindingsPacketMatcher: BindingsPacketMatcher<D>,
44{
45}
46
47pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
56 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
57{
58 type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
61
62 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
64 &mut self,
65 cb: F,
66 ) -> O {
67 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
68 }
69
70 fn with_filter_state_and_nat_ctx<
73 O,
74 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
75 >(
76 &mut self,
77 cb: F,
78 ) -> O;
79}
80
81pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
83 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
84{
85 fn get_local_addr_for_remote(
87 &mut self,
88 device_id: &Self::DeviceId,
89 remote: Option<SpecifiedAddr<I::Addr>>,
90 ) -> Option<Self::AddressId>;
91
92 fn get_address_id(
95 &mut self,
96 device_id: &Self::DeviceId,
97 addr: IpDeviceAddr<I::Addr>,
98 ) -> Option<Self::AddressId>;
99}
100
101pub trait FilterContext<BT: FilterBindingsTypes>:
104 IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
105 + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
106{
107 fn with_all_filter_state_mut<
109 O,
110 F: FnOnce(
111 &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
112 &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
113 ) -> O,
114 >(
115 &mut self,
116 cb: F,
117 ) -> O;
118}
119
/// The verdict returned by [`SocketOpsFilter::on_egress`] for an outgoing
/// packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet is allowed to continue.
    Pass {
        /// Congestion indication reported together with the verdict.
        ///
        /// NOTE(review): presumably signals that the egress path is congested
        /// so the sender can back off; confirm against callers.
        congestion: bool,
    },

    /// The packet must be dropped.
    Drop {
        /// Congestion indication reported together with the verdict.
        ///
        /// NOTE(review): presumably signals that the egress path is congested
        /// so the sender can back off; confirm against callers.
        congestion: bool,
    },
}
135
/// The verdict returned by [`SocketOpsFilter::on_ingress`] for an incoming
/// packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// The packet is accepted.
    Accept,

    /// The packet must be dropped.
    Drop,
}
145
146pub trait SocketOpsFilter<D> {
148 fn on_egress<I: FilterIpExt, P: FilterIpPacket<I>>(
150 &self,
151 packet: &P,
152 device: &D,
153 socket_info: SocketInfo,
154 marks: &Marks,
155 ) -> SocketEgressFilterResult;
156
157 fn on_ingress(
159 &self,
160 ip_version: IpVersion,
161 packet: FragmentedByteSlice<'_, &[u8]>,
162 device: &D,
163 socket_info: SocketInfo,
164 marks: &Marks,
165 ) -> SocketIngressFilterResult;
166}
167
168pub trait SocketOpsFilterBindingContext<D>: TxMetadataBindingsTypes {
170 fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
172}
173
/// The shared fake bindings context hands out
/// `crate::testutil::NoOpSocketOpsFilter` as its socket ops filter.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, S, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, S, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    S: 'static,
    FrameMeta: 'static,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
188
/// Fake implementations of the filtering contexts for use in unit tests.
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::sync::{Arc, Weak};
    use alloc::vec::Vec;
    use core::hash::{Hash, Hasher};
    use core::ops::Deref;
    use core::sync::atomic::{AtomicUsize, Ordering};
    use core::time::Duration;

    use derivative::Derivative;
    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
    use netstack3_base::testutil::{
        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
    };
    use netstack3_base::{
        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
    };
    use netstack3_hashmap::HashMap;

    use super::*;
    use crate::logic::FilterTimerId;
    use crate::logic::nat::NatConfig;
    use crate::state::validation::ValidRoutines;
    use crate::state::{FilterPacketMetadata, IpRoutines, NatRoutines, OneWayBoolean, Routines};
    use crate::{Interfaces, conntrack};

    /// The IP extension traits required by the fakes in this module.
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}

    /// The primary address of a fake device.
    ///
    /// It holds a strong reference to the address. [`FakeWeakAddressId`]s
    /// created from it report themselves as assigned for as long as any strong
    /// reference (this one or a [`FakeAddressId`]) is alive.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// A strong fake address ID.
    ///
    /// Equality and hashing compare the underlying address value.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// A weak fake address ID.
    ///
    /// Equality and hashing use pointer identity of the shared allocation.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
        fn eq(&self, other: &Self) -> bool {
            let Self(lhs) = self;
            let Self(rhs) = other;
            Weak::ptr_eq(lhs, rhs)
        }
    }

    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}

    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Hash the pointer so that hashing agrees with the pointer-based
            // `PartialEq` implementation above.
            let Self(this) = self;
            this.as_ptr().hash(state)
        }
    }

    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
        type Strong = FakeAddressId<I>;

        fn upgrade(&self) -> Option<Self::Strong> {
            let Self(inner) = self;
            inner.upgrade().map(FakeAddressId)
        }

        fn is_assigned(&self) -> bool {
            // The address is considered assigned while any strong reference to
            // it exists.
            let Self(inner) = self;
            inner.strong_count() != 0
        }
    }

    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        // Named `N` rather than `Inspector` to avoid shadowing the imported
        // `Inspector` trait.
        fn record<N: Inspector>(&self, _name: &str, _inspector: &mut N) {
            unimplemented!()
        }
    }

    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner.deref()
        }
    }

    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
        type Weak = FakeWeakAddressId<I>;

        fn downgrade(&self) -> Self::Weak {
            let Self(inner) = self;
            FakeWeakAddressId(Arc::downgrade(inner))
        }

        fn addr(&self) -> IpDeviceAddr<I::Addr> {
            let Self(inner) = self;

            // The assigned-address witness differs between IPv4 and IPv6, so
            // dispatch on the IP version to pick the matching constructor.
            #[derive(GenericOverIp)]
            #[generic_over_ip(I, Ip)]
            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
            I::map_ip(
                WrapIn(inner.addr()),
                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
            )
        }

        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
            let Self(inner) = self;
            **inner
        }
    }

    /// A fake core context holding filter state and a fake NAT context.
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// A fake NAT context mapping each device to a single primary address.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }

    impl<I: TestIpExt> FakeCtx<I> {
        /// Creates a context with default (empty) routines, NAT not marked as
        /// installed, and no device addresses.
        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
            Self {
                state: State {
                    installed_routines: ValidRoutines::default(),
                    uninstalled_routines: Vec::default(),
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with `routines` installed as the IP routines.
        ///
        /// # Panics
        ///
        /// Panics if the routines fail validation.
        pub fn with_ip_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with `routines` installed as the NAT routines, NAT
        /// marked as installed, and `device_addrs` as the devices' primary
        /// addresses.
        ///
        /// # Panics
        ///
        /// Panics if the routines fail validation.
        pub fn with_nat_routines_and_device_addrs(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
            device_addrs: impl IntoIterator<
                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::TRUE,
                },
                // Reuse the NAT context constructor instead of duplicating its
                // address-wrapping logic.
                nat: FakeNatCtx::new(device_addrs),
            }
        }

        /// Returns the connection tracking table.
        ///
        /// Only shared access is needed, so this takes `&self`; callers holding
        /// `&mut FakeCtx` can still call it.
        pub fn conntrack(
            &self,
        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
            &self.state.conntrack
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
        type NatCtx<'a> = FakeNatCtx<I>;

        fn with_filter_state_and_nat_ctx<
            O,
            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
        >(
            &mut self,
            cb: F,
        ) -> O {
            // Split the borrow so the state and NAT context can be handed out
            // together.
            let Self { state, nat } = self;
            cb(state, nat)
        }
    }

    impl<I: TestIpExt> FakeNatCtx<I> {
        /// Creates a NAT context whose devices have the given primary addresses.
        pub fn new(
            device_addrs: impl IntoIterator<
                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self {
                device_addrs: device_addrs
                    .into_iter()
                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
                    .collect(),
            }
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
        /// Returns the device's primary address, ignoring `_remote`.
        fn get_local_addr_for_remote(
            &mut self,
            device_id: &Self::DeviceId,
            _remote: Option<SpecifiedAddr<I::Addr>>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            Some(FakeAddressId(primary.clone()))
        }

        /// Returns the device's primary address if and only if it equals `addr`.
        fn get_address_id(
            &mut self,
            device_id: &Self::DeviceId,
            addr: IpDeviceAddr<I::Addr>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
            let id = FakeAddressId(id.clone());
            (id.addr() == addr).then_some(id)
        }
    }

    /// A fake bindings context with a fake timer context and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }

    impl<I: Ip> FakeBindingsCtx<I> {
        /// Creates a bindings context with default timer and RNG state.
        pub(crate) fn new() -> Self {
            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
        }

        /// Advances the fake clock by `time_elapsed`.
        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
            self.timer_ctx.instant.sleep(time_elapsed)
        }
    }

    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    /// A fake bindings packet matcher.
    ///
    /// It always returns the configured `result` and counts how many times it
    /// was invoked.
    #[derive(Debug)]
    pub struct FakeBindingsPacketMatcher {
        num_calls: AtomicUsize,
        result: bool,
    }

    impl FakeBindingsPacketMatcher {
        /// Creates a shared matcher that always returns `result`.
        pub fn new(result: bool) -> Arc<Self> {
            Arc::new(Self { num_calls: AtomicUsize::new(0), result })
        }

        /// Returns how many times [`BindingsPacketMatcher::matches`] was called.
        pub fn num_calls(&self) -> usize {
            self.num_calls.load(Ordering::SeqCst)
        }
    }

    impl BindingsPacketMatcher<FakeMatcherDeviceId> for FakeBindingsPacketMatcher {
        fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
            &self,
            _packet: &P,
            _interfaces: Interfaces<'_, FakeMatcherDeviceId>,
            _socket_info: &impl FilterPacketMetadata,
        ) -> bool {
            let _: usize = self.num_calls.fetch_add(1, Ordering::SeqCst);
            self.result
        }
    }

    impl InspectableValue for FakeBindingsPacketMatcher {
        fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
            unimplemented!()
        }
    }

    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
        type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
    }

    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
        fn now(&self) -> Self::Instant {
            self.timer_ctx.now()
        }
    }

    // Timer types and operations are all delegated to the fake timer context.
    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }

    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
            self.timer_ctx.new_timer(id)
        }

        fn schedule_timer_instant(
            &mut self,
            time: Self::Instant,
            timer: &mut Self::Timer,
        ) -> Option<Self::Instant> {
            self.timer_ctx.schedule_timer_instant(time, timer)
        }

        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.cancel_timer(timer)
        }

        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.scheduled_instant(timer)
        }

        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
            self.timer_ctx.unique_timer_id(timer)
        }
    }

    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &self,
            f: F,
        ) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
        type Rng<'a>
            = FakeCryptoRng
        where
            Self: 'a;

        /// Returns a clone of the context's fake RNG.
        fn rng(&mut self) -> Self::Rng<'_> {
            self.rng.clone()
        }
    }
}