netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11    InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12    MatcherBindingsTypes, RngContext, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::BindingsPacketMatcher;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
21/// Trait defining required types for filtering provided by bindings.
22pub trait FilterBindingsTypes:
23    InstantBindingsTypes
24    + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
25    + TimerBindingsTypes
26    + 'static
27{
28}
29
30impl<
31    BT: InstantBindingsTypes
32        + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
33        + TimerBindingsTypes
34        + 'static,
35> FilterBindingsTypes for BT
36{
37}
38
39/// Trait aggregating functionality required from bindings.
40pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
41impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
42
43/// The IP version-specific execution context for packet filtering.
44///
45/// This trait exists to abstract over access to the filtering state. It is
46/// useful to implement filtering logic in terms of this trait, as opposed to,
47/// for example, [`crate::logic::FilterHandler`] methods taking the state
48/// directly as an argument, because it allows Netstack3 Core to use lock
49/// ordering types to enforce that filtering state is only acquired at or before
50/// a given lock level, while keeping test code free of locking concerns.
51pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
52    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
53{
54    /// The execution context that allows the filtering engine to perform
55    /// Network Address Translation (NAT).
56    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
57
58    /// Calls the function with a reference to filtering state.
59    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
60        &mut self,
61        cb: F,
62    ) -> O {
63        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
64    }
65
66    /// Calls the function with a reference to filtering state and the NAT
67    /// context.
68    fn with_filter_state_and_nat_ctx<
69        O,
70        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
71    >(
72        &mut self,
73        cb: F,
74    ) -> O;
75}
76
77/// The execution context for Network Address Translation (NAT).
78pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
79    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
80{
81    /// Returns the best local address for communicating with the remote.
82    fn get_local_addr_for_remote(
83        &mut self,
84        device_id: &Self::DeviceId,
85        remote: Option<SpecifiedAddr<I::Addr>>,
86    ) -> Option<Self::AddressId>;
87
88    /// Returns a strongly-held reference to the provided address, if it is assigned
89    /// to the specified device.
90    fn get_address_id(
91        &mut self,
92        device_id: &Self::DeviceId,
93        addr: IpDeviceAddr<I::Addr>,
94    ) -> Option<Self::AddressId>;
95}
96
97/// A context for mutably accessing all filtering state at once, to allow IPv4
98/// and IPv6 filtering state to be modified atomically.
99pub trait FilterContext<BT: FilterBindingsTypes>:
100    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
101    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
102{
103    /// Calls the function with a mutable reference to all filtering state.
104    fn with_all_filter_state_mut<
105        O,
106        F: FnOnce(
107            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
108            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
109        ) -> O,
110    >(
111        &mut self,
112        cb: F,
113    ) -> O;
114}
115
/// Verdict produced by [`SocketOpsFilter::on_egress`] for an outgoing packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet continues to be sent as usual.
    Pass {
        /// Whether congestion should be reported to the higher-level protocol.
        congestion: bool,
    },

    /// The packet is discarded.
    Drop {
        /// Whether congestion should be reported to the higher-level protocol.
        congestion: bool,
    },
}

/// Verdict produced by [`SocketOpsFilter::on_ingress`] for an incoming packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// The packet is accepted.
    Accept,

    /// The packet is discarded.
    Drop,
}
141
142/// Trait for a socket operations filter.
143pub trait SocketOpsFilter<D> {
144    /// Called on every outgoing packet originated from a local socket.
145    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
146        &self,
147        packet: &P,
148        device: &D,
149        cookie: SocketCookie,
150        marks: &Marks,
151    ) -> SocketEgressFilterResult;
152
153    /// Called on every incoming packet handled by a local socket.
154    fn on_ingress(
155        &self,
156        ip_version: IpVersion,
157        packet: FragmentedByteSlice<'_, &[u8]>,
158        device: &D,
159        cookie: SocketCookie,
160        marks: &Marks,
161    ) -> SocketIngressFilterResult;
162}
163
164/// Implemented by bindings to provide socket operations filtering.
165pub trait SocketOpsFilterBindingContext<D>: TxMetadataBindingsTypes {
166    /// Returns the filter that should be called for socket ops.
167    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
168}
169
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        // Fake bindings always hand out the crate's no-op socket ops filter.
        crate::testutil::NoOpSocketOpsFilter
    }
}
184
185#[cfg(test)]
186pub(crate) mod testutil {
187    use alloc::sync::{Arc, Weak};
188    use alloc::vec::Vec;
189    use core::hash::{Hash, Hasher};
190    use core::ops::Deref;
191    use core::sync::atomic::AtomicUsize;
192    use core::time::Duration;
193
194    use derivative::Derivative;
195    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
196    use netstack3_base::testutil::{
197        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
198        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
199    };
200    use netstack3_base::{
201        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
202        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
203    };
204    use netstack3_hashmap::HashMap;
205
206    use super::*;
207    use crate::conntrack;
208    use crate::logic::FilterTimerId;
209    use crate::logic::nat::NatConfig;
210    use crate::state::validation::ValidRoutines;
211    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
212
213    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
214
215    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
216
217    #[derive(Debug)]
218    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
219        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
220    );
221
222    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
223    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
224
225    #[derive(Clone, Debug)]
226    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
227        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
228    );
229
230    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
231        fn eq(&self, other: &Self) -> bool {
232            let Self(lhs) = self;
233            let Self(rhs) = other;
234            Weak::ptr_eq(lhs, rhs)
235        }
236    }
237
238    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
239
240    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
241        fn hash<H: Hasher>(&self, state: &mut H) {
242            let Self(this) = self;
243            this.as_ptr().hash(state)
244        }
245    }
246
247    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
248        type Strong = FakeAddressId<I>;
249
250        fn upgrade(&self) -> Option<Self::Strong> {
251            let Self(inner) = self;
252            inner.upgrade().map(FakeAddressId)
253        }
254
255        fn is_assigned(&self) -> bool {
256            let Self(inner) = self;
257            inner.strong_count() != 0
258        }
259    }
260
261    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
262        fn record<Inspector: netstack3_base::Inspector>(
263            &self,
264            _name: &str,
265            _inspector: &mut Inspector,
266        ) {
267            unimplemented!()
268        }
269    }
270
271    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
272        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
273
274        fn deref(&self) -> &Self::Target {
275            let Self(inner) = self;
276            inner.deref()
277        }
278    }
279
280    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
281        type Weak = FakeWeakAddressId<I>;
282
283        fn downgrade(&self) -> Self::Weak {
284            let Self(inner) = self;
285            FakeWeakAddressId(Arc::downgrade(inner))
286        }
287
288        fn addr(&self) -> IpDeviceAddr<I::Addr> {
289            let Self(inner) = self;
290
291            #[derive(GenericOverIp)]
292            #[generic_over_ip(I, Ip)]
293            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
294            I::map_ip(
295                WrapIn(inner.addr()),
296                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
297                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
298            )
299        }
300
301        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
302            let Self(inner) = self;
303            **inner
304        }
305    }
306
307    pub struct FakeCtx<I: TestIpExt> {
308        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
309        nat: FakeNatCtx<I>,
310    }
311
312    #[derive(Derivative)]
313    #[derivative(Default(bound = ""))]
314    pub struct FakeNatCtx<I: TestIpExt> {
315        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
316    }
317
318    impl<I: TestIpExt> FakeCtx<I> {
319        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
320            Self {
321                state: State {
322                    installed_routines: ValidRoutines::default(),
323                    uninstalled_routines: Vec::default(),
324                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
325                    nat_installed: OneWayBoolean::default(),
326                },
327                nat: FakeNatCtx::default(),
328            }
329        }
330
331        pub fn with_ip_routines(
332            bindings_ctx: &mut FakeBindingsCtx<I>,
333            routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
334        ) -> Self {
335            let (installed_routines, uninstalled_routines) =
336                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
337                    .expect("invalid state");
338            Self {
339                state: State {
340                    installed_routines,
341                    uninstalled_routines,
342                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
343                    nat_installed: OneWayBoolean::default(),
344                },
345                nat: FakeNatCtx::default(),
346            }
347        }
348
349        pub fn with_nat_routines_and_device_addrs(
350            bindings_ctx: &mut FakeBindingsCtx<I>,
351            routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
352            device_addrs: impl IntoIterator<
353                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
354            >,
355        ) -> Self {
356            let (installed_routines, uninstalled_routines) =
357                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
358                    .expect("invalid state");
359            Self {
360                state: State {
361                    installed_routines,
362                    uninstalled_routines,
363                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
364                    nat_installed: OneWayBoolean::TRUE,
365                },
366                nat: FakeNatCtx {
367                    device_addrs: device_addrs
368                        .into_iter()
369                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
370                        .collect(),
371                },
372            }
373        }
374
375        pub fn conntrack(
376            &mut self,
377        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
378            &self.state.conntrack
379        }
380    }
381
382    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
383        type DeviceId = FakeMatcherDeviceId;
384        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
385    }
386
387    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
388        type AddressId = FakeAddressId<I>;
389        type WeakAddressId = FakeWeakAddressId<I>;
390    }
391
392    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
393        type NatCtx<'a> = FakeNatCtx<I>;
394
395        fn with_filter_state_and_nat_ctx<
396            O,
397            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
398        >(
399            &mut self,
400            cb: F,
401        ) -> O {
402            let Self { state, nat } = self;
403            cb(state, nat)
404        }
405    }
406
407    impl<I: TestIpExt> FakeNatCtx<I> {
408        pub fn new(
409            device_addrs: impl IntoIterator<
410                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
411            >,
412        ) -> Self {
413            Self {
414                device_addrs: device_addrs
415                    .into_iter()
416                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
417                    .collect(),
418            }
419        }
420    }
421
422    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
423        type DeviceId = FakeMatcherDeviceId;
424        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
425    }
426
427    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
428        type AddressId = FakeAddressId<I>;
429        type WeakAddressId = FakeWeakAddressId<I>;
430    }
431
432    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
433        fn get_local_addr_for_remote(
434            &mut self,
435            device_id: &Self::DeviceId,
436            _remote: Option<SpecifiedAddr<I::Addr>>,
437        ) -> Option<Self::AddressId> {
438            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
439            Some(FakeAddressId(primary.clone()))
440        }
441
442        fn get_address_id(
443            &mut self,
444            device_id: &Self::DeviceId,
445            addr: IpDeviceAddr<I::Addr>,
446        ) -> Option<Self::AddressId> {
447            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
448            let id = FakeAddressId(id.clone());
449            if id.addr() == addr { Some(id) } else { None }
450        }
451    }
452
453    pub struct FakeBindingsCtx<I: Ip> {
454        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
455        pub rng: FakeCryptoRng,
456    }
457
458    impl<I: Ip> FakeBindingsCtx<I> {
459        pub(crate) fn new() -> Self {
460            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
461        }
462
463        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
464            self.timer_ctx.instant.sleep(time_elapsed)
465        }
466    }
467
468    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
469        type Instant = FakeInstant;
470        type AtomicInstant = FakeAtomicInstant;
471    }
472
473    #[derive(Debug)]
474    pub struct FakeBindingsPacketMatcher {
475        num_calls: AtomicUsize,
476        result: bool,
477    }
478
479    impl FakeBindingsPacketMatcher {
480        pub fn new(result: bool) -> Arc<Self> {
481            Arc::new(Self { num_calls: AtomicUsize::new(0), result })
482        }
483
484        pub fn num_calls(&self) -> usize {
485            self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
486        }
487    }
488
489    impl BindingsPacketMatcher for FakeBindingsPacketMatcher {
490        fn matches<I: FilterIpExt, P: IpPacket<I>>(&self, _packet: &P) -> bool {
491            let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
492            self.result
493        }
494    }
495
496    impl InspectableValue for FakeBindingsPacketMatcher {
497        fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
498            unimplemented!()
499        }
500    }
501
502    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
503        type DeviceClass = FakeDeviceClass;
504        type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
505    }
506
507    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
508        fn now(&self) -> Self::Instant {
509            self.timer_ctx.now()
510        }
511    }
512
513    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
514        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
515        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
516        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
517    }
518
519    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
520        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
521            self.timer_ctx.new_timer(id)
522        }
523
524        fn schedule_timer_instant(
525            &mut self,
526            time: Self::Instant,
527            timer: &mut Self::Timer,
528        ) -> Option<Self::Instant> {
529            self.timer_ctx.schedule_timer_instant(time, timer)
530        }
531
532        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
533            self.timer_ctx.cancel_timer(timer)
534        }
535
536        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
537            self.timer_ctx.scheduled_instant(timer)
538        }
539
540        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
541            self.timer_ctx.unique_timer_id(timer)
542        }
543    }
544
545    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
546        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
547            &self,
548            f: F,
549        ) -> O {
550            f(&self.timer_ctx)
551        }
552
553        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
554            &mut self,
555            f: F,
556        ) -> O {
557            f(&mut self.timer_ctx)
558        }
559    }
560
561    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
562        type Rng<'a>
563            = FakeCryptoRng
564        where
565            Self: 'a;
566
567        fn rng(&mut self) -> Self::Rng<'_> {
568            self.rng.clone()
569        }
570    }
571}