1use core::fmt::{self, Display};
8use net_types::ip::{GenericOverIp, Ip, IpVersion, IpVersionMarker};
9use packet_formats::icmp::IcmpIpExt;
10
/// A per-message-type ICMP filter that can be attached to a raw IP socket.
///
/// The filter is a 256-bit bitmap with one bit per possible ICMP message type
/// value. A cleared bit lets the corresponding type through; a set bit blocks it.
#[derive(Debug, GenericOverIp, PartialEq, Clone)]
#[generic_over_ip(I, Ip)]
pub struct RawIpSocketIcmpFilter<I: IcmpIpExt> {
    /// Type-level marker tying this filter to IP version `I`.
    _marker: IpVersionMarker<I>,
    /// Raw bitmap: the bit for message type `t` is bit `t % 8` of byte `t / 8`.
    /// A value of 1 denies the type; 0 allows it.
    filter: [u8; 32],
}
23
24impl<I: IcmpIpExt> RawIpSocketIcmpFilter<I> {
25 pub const ALLOW_ALL: Self = Self::new([0; 32]);
27
28 pub const DENY_ALL: Self = Self::new([u8::MAX; 32]);
30
31 pub const fn new(filter: [u8; 32]) -> RawIpSocketIcmpFilter<I> {
36 RawIpSocketIcmpFilter { _marker: IpVersionMarker::new(), filter }
37 }
38
39 pub fn into_bytes(self) -> [u8; 32] {
43 let RawIpSocketIcmpFilter { _marker, filter } = self;
44 filter
45 }
46
47 pub(super) fn allows_type(&self, message_type: I::IcmpMessageType) -> bool {
49 let message_type: u8 = message_type.into();
50 let byte: u8 = message_type / 8;
51 let bit: u8 = message_type % 8;
52 (self.filter[usize::from(byte)] & (1 << bit)) == 0
56 }
57
58 #[cfg(test)]
60 fn set_type(&mut self, message_type: u8, allow: bool) {
61 let byte: u8 = message_type / 8;
62 let bit: u8 = message_type % 8;
63 match allow {
67 true => self.filter[usize::from(byte)] &= !(1 << bit),
68 false => self.filter[usize::from(byte)] |= 1 << bit,
69 }
70 }
71}
72
73impl<I: IcmpIpExt> Display for RawIpSocketIcmpFilter<I> {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 let iter = (0..=u8::MAX).filter_map(|i| {
76 let message_type = I::IcmpMessageType::try_from(i).ok()?;
77 self.allows_type(message_type).then_some(message_type)
78 });
79
80 match I::VERSION {
81 IpVersion::V4 => write!(f, "AllowedIcmpMessageTypes [")?,
82 IpVersion::V6 => write!(f, "AllowedIcmpv6MessageTypes [")?,
83 }
84 for (i, message_type) in iter.enumerate() {
85 if i == 0 {
86 write!(f, "\"{message_type:?}\"")?;
87 } else {
88 write!(f, ", \"{message_type:?}\"")?;
89 }
90 }
91 write!(f, "]")
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::string::{String, ToString as _};
    use alloc::vec::Vec;
    use alloc::{format, vec};
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv6};
    use packet_formats::icmp::{Icmpv4MessageType, Icmpv6MessageType};
    use test_case::test_case;

    /// Returns a filter in which `message_type` is the only exception:
    /// - `allow == true`: everything is denied except `message_type`.
    /// - `allow == false`: everything is allowed except `message_type`.
    fn build_precise_filter<I: IcmpIpExt>(
        message_type: u8,
        allow: bool,
    ) -> RawIpSocketIcmpFilter<I> {
        let mut filter = if allow {
            RawIpSocketIcmpFilter::<I>::DENY_ALL
        } else {
            RawIpSocketIcmpFilter::<I>::ALLOW_ALL
        };
        filter.set_type(message_type, allow);
        filter
    }

    #[test_case(21, true, [
        // Type 21 lives in byte 21 / 8 = 2, bit 21 % 8 = 5 (mask 0x20).
        0xFF, 0xFF, 0xDF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    ]; "allow_21")]
    #[test_case(21, false, [
        0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ]; "deny_21")]
    fn build_filter(message_type: u8, allow: bool, expected: [u8; 32]) {
        // The bitmap layout is identical for both IP versions.
        let v4_bytes = build_precise_filter::<Ipv4>(message_type, allow).into_bytes();
        let v6_bytes = build_precise_filter::<Ipv6>(message_type, allow).into_bytes();
        assert_eq!(v4_bytes, expected);
        assert_eq!(v6_bytes, expected);
    }

    #[ip_test(I)]
    fn icmp_filter_allows_type<I: IcmpIpExt>() {
        // Only byte values that decode to a known message type can be probed
        // through `allows_type`.
        let known_types = (0..=u8::MAX).filter_map(|i| {
            I::IcmpMessageType::try_from(i).ok().map(|message_type| (i, message_type))
        });
        for (i, message_type) in known_types {
            let pass_filter = build_precise_filter::<I>(i, true);
            let deny_filter = build_precise_filter::<I>(i, false);
            assert!(pass_filter.allows_type(message_type), "Should allow MessageType:{i}");
            assert!(!deny_filter.allows_type(message_type), "Should deny MessageType:{i}");
        }
    }

    /// Starts from `DENY_ALL`, allows each entry of `allowed_types`, and returns
    /// the resulting filter's `Display` output.
    fn render_with_allowed<I: IcmpIpExt, T: Into<u8>>(allowed_types: Vec<T>) -> String {
        let mut filter = RawIpSocketIcmpFilter::<I>::DENY_ALL;
        allowed_types
            .into_iter()
            .for_each(|allowed_type| filter.set_type(allowed_type.into(), true));
        format!("{filter}")
    }

    #[test_case(vec![] => "AllowedIcmpMessageTypes []".to_string(); "deny_all")]
    #[test_case(vec![Icmpv4MessageType::EchoRequest] =>
        "AllowedIcmpMessageTypes [\"Echo Request\"]".to_string(); "allow_echo_request")]
    #[test_case(vec![Icmpv4MessageType::EchoReply, Icmpv4MessageType::EchoRequest] =>
        "AllowedIcmpMessageTypes [\"Echo Reply\", \"Echo Request\"]".to_string();
        "allow_echo_request_and_reply")]
    fn icmpv4_filter_display(allowed_types: Vec<Icmpv4MessageType>) -> String {
        render_with_allowed::<Ipv4, _>(allowed_types)
    }

    #[test_case(vec![] => "AllowedIcmpv6MessageTypes []".to_string(); "deny_all")]
    #[test_case(vec![Icmpv6MessageType::EchoRequest] =>
        "AllowedIcmpv6MessageTypes [\"Echo Request\"]".to_string(); "allow_echo_request")]
    #[test_case(vec![Icmpv6MessageType::EchoRequest, Icmpv6MessageType::EchoReply] =>
        "AllowedIcmpv6MessageTypes [\"Echo Request\", \"Echo Reply\"]".to_string();
        "allow_echo_request_and_reply")]
    fn icmpv6_filter_display(allowed_types: Vec<Icmpv6MessageType>) -> String {
        render_with_allowed::<Ipv6, _>(allowed_types)
    }
}