netstack3_ip/raw/
filter.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Declare types related to the per-socket filters of raw IP sockets.
6
7use core::fmt::{self, Display};
8use net_types::ip::{GenericOverIp, Ip, IpVersion, IpVersionMarker};
9use packet_formats::icmp::IcmpIpExt;
10
11/// An ICMP filter installed on a raw IP socket.
12#[derive(Debug, GenericOverIp, PartialEq, Clone)]
13#[generic_over_ip(I, Ip)]
14pub struct RawIpSocketIcmpFilter<I: IcmpIpExt> {
15    _marker: IpVersionMarker<I>,
16    /// The raw 256-bit filter. If bit `n` is set, ICMP messages with type `n`
17    /// will be filtered.
18    ///
19    /// Note: if bit `n` is an invalid message type, the packet will be dropped
20    /// regardless of if the bit is set or not.
21    filter: [u8; 32],
22}
23
24impl<I: IcmpIpExt> RawIpSocketIcmpFilter<I> {
25    /// An ICMP filter that allows all message types to be delivered.
26    pub const ALLOW_ALL: Self = Self::new([0; 32]);
27
28    /// An ICMP filter that prevents all message types from being delivered.
29    pub const DENY_ALL: Self = Self::new([u8::MAX; 32]);
30
31    /// Construct a `RawIpSocketIcmpFilter` from the raw bytes.
32    ///
33    /// The array is expected to be little endian. E.g. byte 0 in the array is
34    /// used to control filters for types 0-7.
35    pub const fn new(filter: [u8; 32]) -> RawIpSocketIcmpFilter<I> {
36        RawIpSocketIcmpFilter { _marker: IpVersionMarker::new(), filter }
37    }
38
39    /// Convert the `RawIpSocketIcmpFilter` into the raw bytes.
40    ///
41    /// The array is returned in little endian format.
42    pub fn into_bytes(self) -> [u8; 32] {
43        let RawIpSocketIcmpFilter { _marker, filter } = self;
44        filter
45    }
46
47    /// True if this `RawIpSocketIcmpFilter` allows ICMP messages of the given type.
48    pub(super) fn allows_type(&self, message_type: I::IcmpMessageType) -> bool {
49        let message_type: u8 = message_type.into();
50        let byte: u8 = message_type / 8;
51        let bit: u8 = message_type % 8;
52        // NB: message_type has a max value of 255 (u8::MAX); once divided by 8
53        // its maximum value becomes 31, so `byte` cannot exceed the array
54        // bounds on `self.filter`, which has a length of 32.
55        (self.filter[usize::from(byte)] & (1 << bit)) == 0
56    }
57
58    /// Set whether the given message type is allowed.
59    #[cfg(test)]
60    fn set_type(&mut self, message_type: u8, allow: bool) {
61        let byte: u8 = message_type / 8;
62        let bit: u8 = message_type % 8;
63        // NB: message_type has a max value of 255 (u8::MAX); once divided by 8
64        // its maximum value becomes 31, so `byte` cannot exceed the array
65        // bounds on `self.filter`, which has a length of 32.
66        match allow {
67            true => self.filter[usize::from(byte)] &= !(1 << bit),
68            false => self.filter[usize::from(byte)] |= 1 << bit,
69        }
70    }
71}
72
73impl<I: IcmpIpExt> Display for RawIpSocketIcmpFilter<I> {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        let iter = (0..=u8::MAX).filter_map(|i| {
76            let message_type = I::IcmpMessageType::try_from(i).ok()?;
77            self.allows_type(message_type).then_some(message_type)
78        });
79
80        match I::VERSION {
81            IpVersion::V4 => write!(f, "AllowedIcmpMessageTypes [")?,
82            IpVersion::V6 => write!(f, "AllowedIcmpv6MessageTypes [")?,
83        }
84        for (i, message_type) in iter.enumerate() {
85            if i == 0 {
86                write!(f, "\"{message_type:?}\"")?;
87            } else {
88                write!(f, ", \"{message_type:?}\"")?;
89            }
90        }
91        write!(f, "]")
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::string::{String, ToString as _};
    use alloc::vec::Vec;
    use alloc::{format, vec};
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv6};
    use packet_formats::icmp::{Icmpv4MessageType, Icmpv6MessageType};
    use test_case::test_case;

    /// Builds a filter to precisely allow/disallow a given message type.
    ///
    /// When `allow` is true, the filter is all 1s (deny everything) except for
    /// the bit at `message_type`. When `allow` is false, the filter is all 0s
    /// (allow everything) except for the bit at `message_type`.
    fn build_precise_filter<I: IcmpIpExt>(
        message_type: u8,
        allow: bool,
    ) -> RawIpSocketIcmpFilter<I> {
        let mut filter = match allow {
            true => RawIpSocketIcmpFilter::<I>::DENY_ALL,
            false => RawIpSocketIcmpFilter::<I>::ALLOW_ALL,
        };
        filter.set_type(message_type, allow);
        filter
    }

    // NB: the test helper is complex enough to warrant a test of its own.
    // The cases cover the first bit (type 0), an interior bit (type 21), and
    // the last bit (type 255) of the 256-bit filter.
    #[test_case(0, true, [
        0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        ]; "allow_0")]
    #[test_case(0, false, [
        0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]; "deny_0")]
    #[test_case(21, true, [
        0xFF, 0xFF, 0xDF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        ]; "allow_21")]
    #[test_case(21, false, [
        0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]; "deny_21")]
    #[test_case(255, true, [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F,
        ]; "allow_255")]
    #[test_case(255, false, [
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
        ]; "deny_255")]
    fn build_filter(message_type: u8, allow: bool, expected: [u8; 32]) {
        assert_eq!(build_precise_filter::<Ipv4>(message_type, allow).into_bytes(), expected);
        assert_eq!(build_precise_filter::<Ipv6>(message_type, allow).into_bytes(), expected);
    }

    #[ip_test(I)]
    fn icmp_filter_allows_type<I: IcmpIpExt>() {
        for i in 0..=u8::MAX {
            match I::IcmpMessageType::try_from(i) {
                // This isn't a valid message type; skip testing it.
                Err(_) => continue,
                Ok(message_type) => {
                    let pass_filter = build_precise_filter::<I>(i, true);
                    let deny_filter = build_precise_filter::<I>(i, false);
                    assert!(pass_filter.allows_type(message_type), "Should allow MessageType:{i}");
                    assert!(!deny_filter.allows_type(message_type), "Should deny MessageType:{i}");
                }
            }
        }
    }

    #[ip_test(I)]
    fn icmp_filter_allow_all_and_deny_all<I: IcmpIpExt>() {
        for i in 0..=u8::MAX {
            match I::IcmpMessageType::try_from(i) {
                // This isn't a valid message type; skip testing it.
                Err(_) => continue,
                Ok(message_type) => {
                    assert!(
                        RawIpSocketIcmpFilter::<I>::ALLOW_ALL.allows_type(message_type),
                        "ALLOW_ALL should allow MessageType:{i}"
                    );
                    assert!(
                        !RawIpSocketIcmpFilter::<I>::DENY_ALL.allows_type(message_type),
                        "DENY_ALL should deny MessageType:{i}"
                    );
                }
            }
        }
    }

    #[test_case(vec![] => "AllowedIcmpMessageTypes []".to_string(); "deny_all")]
    #[test_case(vec![Icmpv4MessageType::EchoRequest] =>
        "AllowedIcmpMessageTypes [\"Echo Request\"]".to_string(); "allow_echo_request")]
    #[test_case(vec![Icmpv4MessageType::EchoReply, Icmpv4MessageType::EchoRequest] =>
        "AllowedIcmpMessageTypes [\"Echo Reply\", \"Echo Request\"]".to_string();
        "allow_echo_request_and_reply")]
    fn icmpv4_filter_display(allowed_types: Vec<Icmpv4MessageType>) -> String {
        let mut filter = RawIpSocketIcmpFilter::<Ipv4>::DENY_ALL;
        for allowed_type in allowed_types {
            filter.set_type(allowed_type.into(), true);
        }
        format!("{filter}")
    }

    #[test_case(vec![] => "AllowedIcmpv6MessageTypes []".to_string(); "deny_all")]
    #[test_case(vec![Icmpv6MessageType::EchoRequest] =>
        "AllowedIcmpv6MessageTypes [\"Echo Request\"]".to_string(); "allow_echo_request")]
    #[test_case(vec![Icmpv6MessageType::EchoRequest, Icmpv6MessageType::EchoReply] =>
        "AllowedIcmpv6MessageTypes [\"Echo Request\", \"Echo Reply\"]".to_string();
        "allow_echo_request_and_reply")]
    fn icmpv6_filter_display(allowed_types: Vec<Icmpv6MessageType>) -> String {
        let mut filter = RawIpSocketIcmpFilter::<Ipv6>::DENY_ALL;
        for allowed_type in allowed_types {
            filter.set_type(allowed_type.into(), true);
        }
        format!("{filter}")
    }
}