1use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
12 StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::InterfaceProperties;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
/// Types that bindings must provide to the filtering engine.
///
/// Beyond the instant and timer types required by the [`InstantBindingsTypes`]
/// and [`TimerBindingsTypes`] supertraits, bindings supply the device-class
/// type that interface matching is parameterized over.
pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
    /// The device class type.
    ///
    /// Device IDs in the filtering contexts must implement
    /// [`InterfaceProperties`] over this type.
    type DeviceClass: Clone + Debug;
}
30
31pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
33impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
34
/// Core context that exposes the filtering state for IP version `I`.
///
/// The device ID type must implement [`InterfaceProperties`] over the bindings'
/// device class so that interface matchers can inspect devices. Callers also
/// receive a NAT context ([`Self::NatCtx`]) whose device ID and weak address ID
/// types match this context's.
pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// The NAT context handed to callbacks alongside the filter state.
    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;

    /// Calls `cb` with an immutable reference to the filter state and returns
    /// its result.
    ///
    /// The default implementation delegates to
    /// [`Self::with_filter_state_and_nat_ctx`] and ignores the NAT context.
    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
        &mut self,
        cb: F,
    ) -> O {
        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
    }

    /// Calls `cb` with an immutable reference to the filter state and a mutable
    /// reference to the NAT context, and returns its result.
    fn with_filter_state_and_nat_ctx<
        O,
        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
68
/// Address lookups that NAT needs for IP version `I`.
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns an address ID on `device_id` to use as the local address when
    /// communicating with `remote` (if provided), or `None` if there is none.
    ///
    /// NOTE(review): the selection policy is left to implementors. The test
    /// fake in this file ignores `remote` and returns the device's primary
    /// address.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns the address ID corresponding to `addr` on `device_id`, or
    /// `None` if no matching ID is found.
    ///
    /// NOTE(review): presumably `None` means `addr` is not assigned to the
    /// device. Confirm against the production implementation.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
88
/// Core context that provides mutable access to the IPv4 and IPv6 filtering
/// state at the same time.
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls `cb` with mutable references to the IPv4 and IPv6 filter state
    /// (in that order) and returns its result.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
107
/// Verdict returned by [`SocketOpsFilter::on_egress`] for an outgoing packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// The packet is allowed to proceed.
    Pass {
        /// Whether congestion is indicated for this packet.
        ///
        /// NOTE(review): presumably surfaced to the sending protocol as a
        /// congestion signal. Confirm against callers.
        congestion: bool,
    },

    /// The packet should be dropped.
    Drop {
        /// Whether congestion is indicated for this packet (same meaning as
        /// in `Pass`).
        congestion: bool,
    },
}
123
/// Verdict returned by [`SocketOpsFilter::on_ingress`] for an incoming packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// The packet is accepted.
    Accept,

    /// The packet should be dropped.
    Drop,
}
133
/// A filter applied to packets associated with sockets.
///
/// Each packet is identified by the socket's [`SocketCookie`], the device it
/// traverses, and the socket's [`Marks`].
pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
    /// Filters an egress `packet` on `device` for the socket identified by
    /// `cookie` with the given `marks`.
    ///
    /// The packet is passed as a parsed IP packet for version `I` that can
    /// also be partially serialized.
    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Filters an ingress `packet` on `device` for the socket identified by
    /// `cookie` with the given `marks`.
    ///
    /// Unlike [`Self::on_egress`], the packet is given as raw bytes in a
    /// [`FragmentedByteSlice`], with its IP version passed separately.
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
155
/// Bindings context that provides a [`SocketOpsFilter`] implementation.
pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
    TxMetadataBindingsTypes
{
    /// Returns the socket-operations filter to apply.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
163
// The generic base-crate fake bindings context does not model device classes,
// so it uses the unit type.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    type DeviceClass = ();
}
175
// The generic base-crate fake bindings context always hands out the crate's
// no-op socket-operations filter, for any device ID type.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
190
191#[cfg(test)]
192pub(crate) mod testutil {
193 use alloc::sync::{Arc, Weak};
194 use alloc::vec::Vec;
195 use core::hash::{Hash, Hasher};
196 use core::ops::Deref;
197 use core::time::Duration;
198
199 use derivative::Derivative;
200 use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
201 use netstack3_base::testutil::{
202 FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
203 WithFakeTimerContext,
204 };
205 use netstack3_base::{
206 AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
207 IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
208 };
209 use netstack3_hashmap::HashMap;
210
211 use super::*;
212 use crate::conntrack;
213 use crate::logic::FilterTimerId;
214 use crate::logic::nat::NatConfig;
215 use crate::matchers::testutil::FakeDeviceId;
216 use crate::state::validation::ValidRoutines;
217 use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
218
219 pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
220
221 impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
222
    /// Owner of a fake device address, stored per device in `FakeNatCtx`.
    ///
    /// Strong (`FakeAddressId`) and weak (`FakeWeakAddressId`) fake IDs are
    /// created from clones or downgrades of the inner `Arc`.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// A fake strong address ID sharing the `Arc` of a `FakePrimaryAddressId`.
    ///
    /// Equality and hashing are derived, so they compare the pointed-to
    /// `AddrSubnet` value rather than `Arc` identity.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// A fake weak address ID.
    ///
    /// Its hand-written `PartialEq`/`Hash` impls use the pointer identity of
    /// the underlying allocation, not the address value.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
235
236 impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
237 fn eq(&self, other: &Self) -> bool {
238 let Self(lhs) = self;
239 let Self(rhs) = other;
240 Weak::ptr_eq(lhs, rhs)
241 }
242 }
243
244 impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
245
246 impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
247 fn hash<H: Hasher>(&self, state: &mut H) {
248 let Self(this) = self;
249 this.as_ptr().hash(state)
250 }
251 }
252
253 impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
254 type Strong = FakeAddressId<I>;
255
256 fn upgrade(&self) -> Option<Self::Strong> {
257 let Self(inner) = self;
258 inner.upgrade().map(FakeAddressId)
259 }
260
261 fn is_assigned(&self) -> bool {
262 let Self(inner) = self;
263 inner.strong_count() != 0
264 }
265 }
266
267 impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
268 fn record<Inspector: netstack3_base::Inspector>(
269 &self,
270 _name: &str,
271 _inspector: &mut Inspector,
272 ) {
273 unimplemented!()
274 }
275 }
276
277 impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
278 type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
279
280 fn deref(&self) -> &Self::Target {
281 let Self(inner) = self;
282 inner.deref()
283 }
284 }
285
286 impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
287 type Weak = FakeWeakAddressId<I>;
288
289 fn downgrade(&self) -> Self::Weak {
290 let Self(inner) = self;
291 FakeWeakAddressId(Arc::downgrade(inner))
292 }
293
294 fn addr(&self) -> IpDeviceAddr<I::Addr> {
295 let Self(inner) = self;
296
297 #[derive(GenericOverIp)]
298 #[generic_over_ip(I, Ip)]
299 struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
300 I::map_ip(
301 WrapIn(inner.addr()),
302 |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
303 |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
304 )
305 }
306
307 fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
308 let Self(inner) = self;
309 **inner
310 }
311 }
312
    /// Fake device classes, used as the `DeviceClass` of the test
    /// `FakeBindingsCtx`.
    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    pub enum FakeDeviceClass {
        Ethernet,
        Wlan,
    }
318
    /// A fake core context holding the filter state for a single IP version
    /// and a fake NAT context.
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// A fake NAT context that maps each device to a single primary address.
    // `Default` is derived via `derivative` with an empty bound so that it does
    // not require `I: Default`.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
    }
329
330 impl<I: TestIpExt> FakeCtx<I> {
331 pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
332 Self {
333 state: State {
334 installed_routines: ValidRoutines::default(),
335 uninstalled_routines: Vec::default(),
336 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
337 nat_installed: OneWayBoolean::default(),
338 },
339 nat: FakeNatCtx::default(),
340 }
341 }
342
343 pub fn with_ip_routines(
344 bindings_ctx: &mut FakeBindingsCtx<I>,
345 routines: IpRoutines<I, FakeDeviceClass, ()>,
346 ) -> Self {
347 let (installed_routines, uninstalled_routines) =
348 ValidRoutines::new(Routines { ip: routines, ..Default::default() })
349 .expect("invalid state");
350 Self {
351 state: State {
352 installed_routines,
353 uninstalled_routines,
354 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
355 nat_installed: OneWayBoolean::default(),
356 },
357 nat: FakeNatCtx::default(),
358 }
359 }
360
361 pub fn with_nat_routines_and_device_addrs(
362 bindings_ctx: &mut FakeBindingsCtx<I>,
363 routines: NatRoutines<I, FakeDeviceClass, ()>,
364 device_addrs: impl IntoIterator<
365 Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
366 >,
367 ) -> Self {
368 let (installed_routines, uninstalled_routines) =
369 ValidRoutines::new(Routines { nat: routines, ..Default::default() })
370 .expect("invalid state");
371 Self {
372 state: State {
373 installed_routines,
374 uninstalled_routines,
375 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
376 nat_installed: OneWayBoolean::TRUE,
377 },
378 nat: FakeNatCtx {
379 device_addrs: device_addrs
380 .into_iter()
381 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
382 .collect(),
383 },
384 }
385 }
386
387 pub fn conntrack(
388 &mut self,
389 ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
390 &self.state.conntrack
391 }
392 }
393
    // `FakeCtx` identifies devices with the fake device ID types.
    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    // `FakeCtx` uses the fake strong/weak address IDs defined in this module.
    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }
403
404 impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
405 type NatCtx<'a> = FakeNatCtx<I>;
406
407 fn with_filter_state_and_nat_ctx<
408 O,
409 F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
410 >(
411 &mut self,
412 cb: F,
413 ) -> O {
414 let Self { state, nat } = self;
415 cb(state, nat)
416 }
417 }
418
419 impl<I: TestIpExt> FakeNatCtx<I> {
420 pub fn new(
421 device_addrs: impl IntoIterator<
422 Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
423 >,
424 ) -> Self {
425 Self {
426 device_addrs: device_addrs
427 .into_iter()
428 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
429 .collect(),
430 }
431 }
432 }
433
    // `FakeNatCtx` uses the same fake device ID types as `FakeCtx`.
    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    // `FakeNatCtx` uses the fake strong/weak address IDs defined in this module.
    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }
443
444 impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
445 fn get_local_addr_for_remote(
446 &mut self,
447 device_id: &Self::DeviceId,
448 _remote: Option<SpecifiedAddr<I::Addr>>,
449 ) -> Option<Self::AddressId> {
450 let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
451 Some(FakeAddressId(primary.clone()))
452 }
453
454 fn get_address_id(
455 &mut self,
456 device_id: &Self::DeviceId,
457 addr: IpDeviceAddr<I::Addr>,
458 ) -> Option<Self::AddressId> {
459 let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
460 let id = FakeAddressId(id.clone());
461 if id.addr() == addr { Some(id) } else { None }
462 }
463 }
464
    /// Fake bindings context for filter tests, holding a fake timer context
    /// keyed by `FilterTimerId<I>` and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }
469
470 impl<I: Ip> FakeBindingsCtx<I> {
471 pub(crate) fn new() -> Self {
472 Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
473 }
474
475 pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
476 self.timer_ctx.instant.sleep(time_elapsed)
477 }
478 }
479
480 impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
481 type Instant = FakeInstant;
482 type AtomicInstant = FakeAtomicInstant;
483 }
484
485 impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
486 type DeviceClass = FakeDeviceClass;
487 }
488
489 impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
490 fn now(&self) -> Self::Instant {
491 self.timer_ctx.now()
492 }
493 }
494
495 impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
496 type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
497 type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
498 type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
499 }
500
501 impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
502 fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
503 self.timer_ctx.new_timer(id)
504 }
505
506 fn schedule_timer_instant(
507 &mut self,
508 time: Self::Instant,
509 timer: &mut Self::Timer,
510 ) -> Option<Self::Instant> {
511 self.timer_ctx.schedule_timer_instant(time, timer)
512 }
513
514 fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
515 self.timer_ctx.cancel_timer(timer)
516 }
517
518 fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
519 self.timer_ctx.scheduled_instant(timer)
520 }
521
522 fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
523 self.timer_ctx.unique_timer_id(timer)
524 }
525 }
526
527 impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
528 fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
529 &self,
530 f: F,
531 ) -> O {
532 f(&self.timer_ctx)
533 }
534
535 fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
536 &mut self,
537 f: F,
538 ) -> O {
539 f(&mut self.timer_ctx)
540 }
541 }
542
543 impl<I: Ip> RngContext for FakeBindingsCtx<I> {
544 type Rng<'a>
545 = FakeCryptoRng
546 where
547 Self: 'a;
548
549 fn rng(&mut self) -> Self::Rng<'_> {
550 self.rng.clone()
551 }
552 }
553}