Skip to main content

netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9
10pub use netstack3_base::Marks;
11pub use netstack3_base::socket::{EitherIpProto, SocketInfo};
12use netstack3_base::{
13    InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext,
14    MatcherBindingsTypes, RngContext, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
15};
16use packet::FragmentedByteSlice;
17use packet_formats::ip::IpExt;
18
19use crate::FilterIpExt;
20use crate::matchers::BindingsPacketMatcher;
21use crate::packets::FilterIpPacket;
22use crate::state::State;
23
24/// Trait defining required types for filtering provided by bindings.
25pub trait FilterBindingsTypes:
26    InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static
27{
28}
29
30impl<BT: InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static>
31    FilterBindingsTypes for BT
32{
33}
34
35/// Trait aggregating functionality required from bindings.
36pub trait FilterBindingsContext<D>:
37    TimerContext + RngContext + FilterBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher<D>>
38{
39}
40impl<D, BC> FilterBindingsContext<D> for BC
41where
42    BC: TimerContext + RngContext + FilterBindingsTypes,
43    BC::BindingsPacketMatcher: BindingsPacketMatcher<D>,
44{
45}
46
47/// The IP version-specific execution context for packet filtering.
48///
49/// This trait exists to abstract over access to the filtering state. It is
50/// useful to implement filtering logic in terms of this trait, as opposed to,
51/// for example, [`crate::logic::FilterHandler`] methods taking the state
52/// directly as an argument, because it allows Netstack3 Core to use lock
53/// ordering types to enforce that filtering state is only acquired at or before
54/// a given lock level, while keeping test code free of locking concerns.
55pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
56    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
57{
58    /// The execution context that allows the filtering engine to perform
59    /// Network Address Translation (NAT).
60    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
61
62    /// Calls the function with a reference to filtering state.
63    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
64        &mut self,
65        cb: F,
66    ) -> O {
67        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
68    }
69
70    /// Calls the function with a reference to filtering state and the NAT
71    /// context.
72    fn with_filter_state_and_nat_ctx<
73        O,
74        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
75    >(
76        &mut self,
77        cb: F,
78    ) -> O;
79}
80
81/// The execution context for Network Address Translation (NAT).
82pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
83    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
84{
85    /// Returns the best local address for communicating with the remote.
86    fn get_local_addr_for_remote(
87        &mut self,
88        device_id: &Self::DeviceId,
89        remote: Option<SpecifiedAddr<I::Addr>>,
90    ) -> Option<Self::AddressId>;
91
92    /// Returns a strongly-held reference to the provided address, if it is assigned
93    /// to the specified device.
94    fn get_address_id(
95        &mut self,
96        device_id: &Self::DeviceId,
97        addr: IpDeviceAddr<I::Addr>,
98    ) -> Option<Self::AddressId>;
99}
100
101/// A context for mutably accessing all filtering state at once, to allow IPv4
102/// and IPv6 filtering state to be modified atomically.
103pub trait FilterContext<BT: FilterBindingsTypes>:
104    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
105    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
106{
107    /// Calls the function with a mutable reference to all filtering state.
108    fn with_all_filter_state_mut<
109        O,
110        F: FnOnce(
111            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
112            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
113        ) -> O,
114    >(
115        &mut self,
116        cb: F,
117    ) -> O;
118}
119
/// Verdict produced by [`SocketOpsFilter::on_egress`] for an outgoing packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet should be sent as usual.
    Pass {
        /// Whether congestion should be reported to the higher-level protocol.
        congestion: bool,
    },

    /// The packet should be discarded.
    Drop {
        /// Whether congestion should be reported to the higher-level protocol.
        congestion: bool,
    },
}
135
/// Verdict produced by [`SocketOpsFilter::on_ingress`] for an incoming packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// The packet should be accepted.
    Accept,

    /// The packet should be discarded.
    Drop,
}
145
146/// Trait for a socket operations filter.
147pub trait SocketOpsFilter<D> {
148    /// Called on every outgoing packet originated from a local socket.
149    fn on_egress<I: FilterIpExt, P: FilterIpPacket<I>>(
150        &self,
151        packet: &P,
152        device: &D,
153        socket_info: SocketInfo,
154        marks: &Marks,
155    ) -> SocketEgressFilterResult;
156
157    /// Called on every incoming packet handled by a local socket.
158    fn on_ingress(
159        &self,
160        ip_version: IpVersion,
161        packet: FragmentedByteSlice<'_, &[u8]>,
162        device: &D,
163        socket_info: SocketInfo,
164        marks: &Marks,
165    ) -> SocketIngressFilterResult;
166}
167
168/// Implemented by bindings to provide socket operations filtering.
169pub trait SocketOpsFilterBindingContext<D>: TxMetadataBindingsTypes {
170    /// Returns the filter that should be called for socket ops.
171    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
172}
173
// Fake bindings from `netstack3_base` always hand out the crate's
// `NoOpSocketOpsFilter`, regardless of the device type `D`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
188
189#[cfg(test)]
190pub(crate) mod testutil {
191    use alloc::sync::{Arc, Weak};
192    use alloc::vec::Vec;
193    use core::hash::{Hash, Hasher};
194    use core::ops::Deref;
195    use core::sync::atomic::AtomicUsize;
196    use core::time::Duration;
197
198    use derivative::Derivative;
199    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
200    use netstack3_base::testutil::{
201        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
202        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
203    };
204    use netstack3_base::{
205        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
206        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
207    };
208    use netstack3_hashmap::HashMap;
209
210    use super::*;
211    use crate::logic::FilterTimerId;
212    use crate::logic::nat::NatConfig;
213    use crate::state::validation::ValidRoutines;
214    use crate::state::{FilterPacketMetadata, IpRoutines, NatRoutines, OneWayBoolean, Routines};
215    use crate::{Interfaces, conntrack};
216
217    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
218
219    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
220
221    #[derive(Debug)]
222    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
223        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
224    );
225
226    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
227    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
228
229    #[derive(Clone, Debug)]
230    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
231        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
232    );
233
234    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
235        fn eq(&self, other: &Self) -> bool {
236            let Self(lhs) = self;
237            let Self(rhs) = other;
238            Weak::ptr_eq(lhs, rhs)
239        }
240    }
241
242    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
243
244    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
245        fn hash<H: Hasher>(&self, state: &mut H) {
246            let Self(this) = self;
247            this.as_ptr().hash(state)
248        }
249    }
250
251    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
252        type Strong = FakeAddressId<I>;
253
254        fn upgrade(&self) -> Option<Self::Strong> {
255            let Self(inner) = self;
256            inner.upgrade().map(FakeAddressId)
257        }
258
259        fn is_assigned(&self) -> bool {
260            let Self(inner) = self;
261            inner.strong_count() != 0
262        }
263    }
264
265    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
266        fn record<Inspector: netstack3_base::Inspector>(
267            &self,
268            _name: &str,
269            _inspector: &mut Inspector,
270        ) {
271            unimplemented!()
272        }
273    }
274
275    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
276        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
277
278        fn deref(&self) -> &Self::Target {
279            let Self(inner) = self;
280            inner.deref()
281        }
282    }
283
284    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
285        type Weak = FakeWeakAddressId<I>;
286
287        fn downgrade(&self) -> Self::Weak {
288            let Self(inner) = self;
289            FakeWeakAddressId(Arc::downgrade(inner))
290        }
291
292        fn addr(&self) -> IpDeviceAddr<I::Addr> {
293            let Self(inner) = self;
294
295            #[derive(GenericOverIp)]
296            #[generic_over_ip(I, Ip)]
297            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
298            I::map_ip(
299                WrapIn(inner.addr()),
300                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
301                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
302            )
303        }
304
305        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
306            let Self(inner) = self;
307            **inner
308        }
309    }
310
311    pub struct FakeCtx<I: TestIpExt> {
312        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
313        nat: FakeNatCtx<I>,
314    }
315
316    #[derive(Derivative)]
317    #[derivative(Default(bound = ""))]
318    pub struct FakeNatCtx<I: TestIpExt> {
319        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
320    }
321
322    impl<I: TestIpExt> FakeCtx<I> {
323        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
324            Self {
325                state: State {
326                    installed_routines: ValidRoutines::default(),
327                    uninstalled_routines: Vec::default(),
328                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
329                    nat_installed: OneWayBoolean::default(),
330                },
331                nat: FakeNatCtx::default(),
332            }
333        }
334
335        pub fn with_ip_routines(
336            bindings_ctx: &mut FakeBindingsCtx<I>,
337            routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
338        ) -> Self {
339            let (installed_routines, uninstalled_routines) =
340                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
341                    .expect("invalid state");
342            Self {
343                state: State {
344                    installed_routines,
345                    uninstalled_routines,
346                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
347                    nat_installed: OneWayBoolean::default(),
348                },
349                nat: FakeNatCtx::default(),
350            }
351        }
352
353        pub fn with_nat_routines_and_device_addrs(
354            bindings_ctx: &mut FakeBindingsCtx<I>,
355            routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
356            device_addrs: impl IntoIterator<
357                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
358            >,
359        ) -> Self {
360            let (installed_routines, uninstalled_routines) =
361                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
362                    .expect("invalid state");
363            Self {
364                state: State {
365                    installed_routines,
366                    uninstalled_routines,
367                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
368                    nat_installed: OneWayBoolean::TRUE,
369                },
370                nat: FakeNatCtx {
371                    device_addrs: device_addrs
372                        .into_iter()
373                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
374                        .collect(),
375                },
376            }
377        }
378
379        pub fn conntrack(
380            &mut self,
381        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
382            &self.state.conntrack
383        }
384    }
385
386    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
387        type DeviceId = FakeMatcherDeviceId;
388        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
389    }
390
391    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
392        type AddressId = FakeAddressId<I>;
393        type WeakAddressId = FakeWeakAddressId<I>;
394    }
395
396    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
397        type NatCtx<'a> = FakeNatCtx<I>;
398
399        fn with_filter_state_and_nat_ctx<
400            O,
401            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
402        >(
403            &mut self,
404            cb: F,
405        ) -> O {
406            let Self { state, nat } = self;
407            cb(state, nat)
408        }
409    }
410
411    impl<I: TestIpExt> FakeNatCtx<I> {
412        pub fn new(
413            device_addrs: impl IntoIterator<
414                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
415            >,
416        ) -> Self {
417            Self {
418                device_addrs: device_addrs
419                    .into_iter()
420                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
421                    .collect(),
422            }
423        }
424    }
425
426    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
427        type DeviceId = FakeMatcherDeviceId;
428        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
429    }
430
431    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
432        type AddressId = FakeAddressId<I>;
433        type WeakAddressId = FakeWeakAddressId<I>;
434    }
435
436    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
437        fn get_local_addr_for_remote(
438            &mut self,
439            device_id: &Self::DeviceId,
440            _remote: Option<SpecifiedAddr<I::Addr>>,
441        ) -> Option<Self::AddressId> {
442            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
443            Some(FakeAddressId(primary.clone()))
444        }
445
446        fn get_address_id(
447            &mut self,
448            device_id: &Self::DeviceId,
449            addr: IpDeviceAddr<I::Addr>,
450        ) -> Option<Self::AddressId> {
451            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
452            let id = FakeAddressId(id.clone());
453            if id.addr() == addr { Some(id) } else { None }
454        }
455    }
456
457    pub struct FakeBindingsCtx<I: Ip> {
458        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
459        pub rng: FakeCryptoRng,
460    }
461
462    impl<I: Ip> FakeBindingsCtx<I> {
463        pub(crate) fn new() -> Self {
464            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
465        }
466
467        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
468            self.timer_ctx.instant.sleep(time_elapsed)
469        }
470    }
471
472    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
473        type Instant = FakeInstant;
474        type AtomicInstant = FakeAtomicInstant;
475    }
476
477    #[derive(Debug)]
478    pub struct FakeBindingsPacketMatcher {
479        num_calls: AtomicUsize,
480        result: bool,
481    }
482
483    impl FakeBindingsPacketMatcher {
484        pub fn new(result: bool) -> Arc<Self> {
485            Arc::new(Self { num_calls: AtomicUsize::new(0), result })
486        }
487
488        pub fn num_calls(&self) -> usize {
489            self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
490        }
491    }
492
493    impl BindingsPacketMatcher<FakeMatcherDeviceId> for FakeBindingsPacketMatcher {
494        fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
495            &self,
496            _packet: &P,
497            _interfaces: Interfaces<'_, FakeMatcherDeviceId>,
498            _socket_info: &impl FilterPacketMetadata,
499        ) -> bool {
500            let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
501            self.result
502        }
503    }
504
505    impl InspectableValue for FakeBindingsPacketMatcher {
506        fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
507            unimplemented!()
508        }
509    }
510
511    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
512        type DeviceClass = FakeDeviceClass;
513        type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
514    }
515
516    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
517        fn now(&self) -> Self::Instant {
518            self.timer_ctx.now()
519        }
520    }
521
522    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
523        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
524        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
525        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
526    }
527
528    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
529        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
530            self.timer_ctx.new_timer(id)
531        }
532
533        fn schedule_timer_instant(
534            &mut self,
535            time: Self::Instant,
536            timer: &mut Self::Timer,
537        ) -> Option<Self::Instant> {
538            self.timer_ctx.schedule_timer_instant(time, timer)
539        }
540
541        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
542            self.timer_ctx.cancel_timer(timer)
543        }
544
545        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
546            self.timer_ctx.scheduled_instant(timer)
547        }
548
549        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
550            self.timer_ctx.unique_timer_id(timer)
551        }
552    }
553
554    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
555        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
556            &self,
557            f: F,
558        ) -> O {
559            f(&self.timer_ctx)
560        }
561
562        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
563            &mut self,
564            f: F,
565        ) -> O {
566            f(&mut self.timer_ctx)
567        }
568    }
569
570    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
571        type Rng<'a>
572            = FakeCryptoRng
573        where
574            Self: 'a;
575
576        fn rng(&mut self) -> Self::Rng<'_> {
577            self.rng.clone()
578        }
579    }
580}