netstack3_filter/
context.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11    InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
12    StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::InterfaceProperties;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
21/// Trait defining required types for filtering provided by bindings.
22///
23/// Allows rules that match on device class to be installed, storing the
24/// [`FilterBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
25/// Core to have Bindings provide the type since it is platform-specific.
26pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
27    /// The device class type for devices installed in the netstack.
28    type DeviceClass: Clone + Debug;
29}
30
31/// Trait aggregating functionality required from bindings.
32pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
33impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
34
35/// The IP version-specific execution context for packet filtering.
36///
37/// This trait exists to abstract over access to the filtering state. It is
38/// useful to implement filtering logic in terms of this trait, as opposed to,
39/// for example, [`crate::logic::FilterHandler`] methods taking the state
40/// directly as an argument, because it allows Netstack3 Core to use lock
41/// ordering types to enforce that filtering state is only acquired at or before
42/// a given lock level, while keeping test code free of locking concerns.
43pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
44    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
45{
46    /// The execution context that allows the filtering engine to perform
47    /// Network Address Translation (NAT).
48    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
49
50    /// Calls the function with a reference to filtering state.
51    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
52        &mut self,
53        cb: F,
54    ) -> O {
55        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
56    }
57
58    /// Calls the function with a reference to filtering state and the NAT
59    /// context.
60    fn with_filter_state_and_nat_ctx<
61        O,
62        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
63    >(
64        &mut self,
65        cb: F,
66    ) -> O;
67}
68
69/// The execution context for Network Address Translation (NAT).
70pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
71    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
72{
73    /// Returns the best local address for communicating with the remote.
74    fn get_local_addr_for_remote(
75        &mut self,
76        device_id: &Self::DeviceId,
77        remote: Option<SpecifiedAddr<I::Addr>>,
78    ) -> Option<Self::AddressId>;
79
80    /// Returns a strongly-held reference to the provided address, if it is assigned
81    /// to the specified device.
82    fn get_address_id(
83        &mut self,
84        device_id: &Self::DeviceId,
85        addr: IpDeviceAddr<I::Addr>,
86    ) -> Option<Self::AddressId>;
87}
88
89/// A context for mutably accessing all filtering state at once, to allow IPv4
90/// and IPv6 filtering state to be modified atomically.
91pub trait FilterContext<BT: FilterBindingsTypes>:
92    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
93    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
94{
95    /// Calls the function with a mutable reference to all filtering state.
96    fn with_all_filter_state_mut<
97        O,
98        F: FnOnce(
99            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
100            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
101        ) -> O,
102    >(
103        &mut self,
104        cb: F,
105    ) -> O;
106}
107
/// Verdict produced by [`SocketOpsFilter::on_egress`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet should be sent as usual.
    Pass {
        /// Whether congestion should be signaled to the higher level
        /// protocol.
        congestion: bool,
    },

    /// The packet should be dropped.
    Drop {
        /// Whether congestion should be signaled to the higher level
        /// protocol.
        congestion: bool,
    },
}
123
/// Verdict produced by [`SocketOpsFilter::on_ingress`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// The packet should be accepted.
    Accept,

    /// The packet should be dropped.
    Drop,
}
133
134/// Trait for a socket operations filter.
135pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
136    /// Called on every outgoing packet originated from a local socket.
137    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
138        &self,
139        packet: &P,
140        device: &D,
141        cookie: SocketCookie,
142        marks: &Marks,
143    ) -> SocketEgressFilterResult;
144
145    /// Called on every incoming packet handled by a local socket.
146    fn on_ingress(
147        &self,
148        ip_version: IpVersion,
149        packet: FragmentedByteSlice<'_, &[u8]>,
150        device: &D,
151        cookie: SocketCookie,
152        marks: &Marks,
153    ) -> SocketIngressFilterResult;
154}
155
156/// Implemented by bindings to provide socket operations filtering.
157pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
158    TxMetadataBindingsTypes
159{
160    /// Returns the filter that should be called for socket ops.
161    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
162}
163
// The shared fake bindings context from `netstack3_base` has no notion of a
// device class, so the unit type is used for it.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    type DeviceClass = ();
}
175
// The shared fake bindings context always hands out the no-op socket ops
// filter from `crate::testutil`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
190
191#[cfg(test)]
192pub(crate) mod testutil {
193    use alloc::sync::{Arc, Weak};
194    use alloc::vec::Vec;
195    use core::hash::{Hash, Hasher};
196    use core::ops::Deref;
197    use core::time::Duration;
198
199    use derivative::Derivative;
200    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
201    use netstack3_base::testutil::{
202        FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
203        WithFakeTimerContext,
204    };
205    use netstack3_base::{
206        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
207        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
208    };
209    use netstack3_hashmap::HashMap;
210
211    use super::*;
212    use crate::conntrack;
213    use crate::logic::FilterTimerId;
214    use crate::logic::nat::NatConfig;
215    use crate::matchers::testutil::FakeDeviceId;
216    use crate::state::validation::ValidRoutines;
217    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
218
219    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
220
221    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
222
223    #[derive(Debug)]
224    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
225        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
226    );
227
228    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
229    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
230
231    #[derive(Clone, Debug)]
232    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
233        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
234    );
235
236    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
237        fn eq(&self, other: &Self) -> bool {
238            let Self(lhs) = self;
239            let Self(rhs) = other;
240            Weak::ptr_eq(lhs, rhs)
241        }
242    }
243
244    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
245
246    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
247        fn hash<H: Hasher>(&self, state: &mut H) {
248            let Self(this) = self;
249            this.as_ptr().hash(state)
250        }
251    }
252
253    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
254        type Strong = FakeAddressId<I>;
255
256        fn upgrade(&self) -> Option<Self::Strong> {
257            let Self(inner) = self;
258            inner.upgrade().map(FakeAddressId)
259        }
260
261        fn is_assigned(&self) -> bool {
262            let Self(inner) = self;
263            inner.strong_count() != 0
264        }
265    }
266
267    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
268        fn record<Inspector: netstack3_base::Inspector>(
269            &self,
270            _name: &str,
271            _inspector: &mut Inspector,
272        ) {
273            unimplemented!()
274        }
275    }
276
277    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
278        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
279
280        fn deref(&self) -> &Self::Target {
281            let Self(inner) = self;
282            inner.deref()
283        }
284    }
285
286    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
287        type Weak = FakeWeakAddressId<I>;
288
289        fn downgrade(&self) -> Self::Weak {
290            let Self(inner) = self;
291            FakeWeakAddressId(Arc::downgrade(inner))
292        }
293
294        fn addr(&self) -> IpDeviceAddr<I::Addr> {
295            let Self(inner) = self;
296
297            #[derive(GenericOverIp)]
298            #[generic_over_ip(I, Ip)]
299            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
300            I::map_ip(
301                WrapIn(inner.addr()),
302                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
303                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
304            )
305        }
306
307        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
308            let Self(inner) = self;
309            **inner
310        }
311    }
312
313    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
314    pub enum FakeDeviceClass {
315        Ethernet,
316        Wlan,
317    }
318
319    pub struct FakeCtx<I: TestIpExt> {
320        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
321        nat: FakeNatCtx<I>,
322    }
323
324    #[derive(Derivative)]
325    #[derivative(Default(bound = ""))]
326    pub struct FakeNatCtx<I: TestIpExt> {
327        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
328    }
329
330    impl<I: TestIpExt> FakeCtx<I> {
331        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
332            Self {
333                state: State {
334                    installed_routines: ValidRoutines::default(),
335                    uninstalled_routines: Vec::default(),
336                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
337                    nat_installed: OneWayBoolean::default(),
338                },
339                nat: FakeNatCtx::default(),
340            }
341        }
342
343        pub fn with_ip_routines(
344            bindings_ctx: &mut FakeBindingsCtx<I>,
345            routines: IpRoutines<I, FakeDeviceClass, ()>,
346        ) -> Self {
347            let (installed_routines, uninstalled_routines) =
348                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
349                    .expect("invalid state");
350            Self {
351                state: State {
352                    installed_routines,
353                    uninstalled_routines,
354                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
355                    nat_installed: OneWayBoolean::default(),
356                },
357                nat: FakeNatCtx::default(),
358            }
359        }
360
361        pub fn with_nat_routines_and_device_addrs(
362            bindings_ctx: &mut FakeBindingsCtx<I>,
363            routines: NatRoutines<I, FakeDeviceClass, ()>,
364            device_addrs: impl IntoIterator<
365                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
366            >,
367        ) -> Self {
368            let (installed_routines, uninstalled_routines) =
369                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
370                    .expect("invalid state");
371            Self {
372                state: State {
373                    installed_routines,
374                    uninstalled_routines,
375                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
376                    nat_installed: OneWayBoolean::TRUE,
377                },
378                nat: FakeNatCtx {
379                    device_addrs: device_addrs
380                        .into_iter()
381                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
382                        .collect(),
383                },
384            }
385        }
386
387        pub fn conntrack(
388            &mut self,
389        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
390            &self.state.conntrack
391        }
392    }
393
394    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
395        type DeviceId = FakeDeviceId;
396        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
397    }
398
399    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
400        type AddressId = FakeAddressId<I>;
401        type WeakAddressId = FakeWeakAddressId<I>;
402    }
403
404    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
405        type NatCtx<'a> = FakeNatCtx<I>;
406
407        fn with_filter_state_and_nat_ctx<
408            O,
409            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
410        >(
411            &mut self,
412            cb: F,
413        ) -> O {
414            let Self { state, nat } = self;
415            cb(state, nat)
416        }
417    }
418
419    impl<I: TestIpExt> FakeNatCtx<I> {
420        pub fn new(
421            device_addrs: impl IntoIterator<
422                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
423            >,
424        ) -> Self {
425            Self {
426                device_addrs: device_addrs
427                    .into_iter()
428                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
429                    .collect(),
430            }
431        }
432    }
433
434    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
435        type DeviceId = FakeDeviceId;
436        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
437    }
438
439    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
440        type AddressId = FakeAddressId<I>;
441        type WeakAddressId = FakeWeakAddressId<I>;
442    }
443
444    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
445        fn get_local_addr_for_remote(
446            &mut self,
447            device_id: &Self::DeviceId,
448            _remote: Option<SpecifiedAddr<I::Addr>>,
449        ) -> Option<Self::AddressId> {
450            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
451            Some(FakeAddressId(primary.clone()))
452        }
453
454        fn get_address_id(
455            &mut self,
456            device_id: &Self::DeviceId,
457            addr: IpDeviceAddr<I::Addr>,
458        ) -> Option<Self::AddressId> {
459            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
460            let id = FakeAddressId(id.clone());
461            if id.addr() == addr { Some(id) } else { None }
462        }
463    }
464
465    pub struct FakeBindingsCtx<I: Ip> {
466        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
467        pub rng: FakeCryptoRng,
468    }
469
470    impl<I: Ip> FakeBindingsCtx<I> {
471        pub(crate) fn new() -> Self {
472            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
473        }
474
475        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
476            self.timer_ctx.instant.sleep(time_elapsed)
477        }
478    }
479
480    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
481        type Instant = FakeInstant;
482        type AtomicInstant = FakeAtomicInstant;
483    }
484
485    impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
486        type DeviceClass = FakeDeviceClass;
487    }
488
489    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
490        fn now(&self) -> Self::Instant {
491            self.timer_ctx.now()
492        }
493    }
494
495    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
496        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
497        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
498        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
499    }
500
501    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
502        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
503            self.timer_ctx.new_timer(id)
504        }
505
506        fn schedule_timer_instant(
507            &mut self,
508            time: Self::Instant,
509            timer: &mut Self::Timer,
510        ) -> Option<Self::Instant> {
511            self.timer_ctx.schedule_timer_instant(time, timer)
512        }
513
514        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
515            self.timer_ctx.cancel_timer(timer)
516        }
517
518        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
519            self.timer_ctx.scheduled_instant(timer)
520        }
521
522        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
523            self.timer_ctx.unique_timer_id(timer)
524        }
525    }
526
527    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
528        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
529            &self,
530            f: F,
531        ) -> O {
532            f(&self.timer_ctx)
533        }
534
535        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
536            &mut self,
537            f: F,
538        ) -> O {
539            f(&mut self.timer_ctx)
540        }
541    }
542
543    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
544        type Rng<'a>
545            = FakeCryptoRng
546        where
547            Self: 'a;
548
549        fn rng(&mut self) -> Self::Rng<'_> {
550            self.rng.clone()
551        }
552    }
553}