use core::fmt::{self, Display};
use net_types::ip::{GenericOverIp, Ip, IpVersion, IpVersionMarker};
use packet_formats::icmp::IcmpIpExt;
/// A per-message-type ICMP filter for a raw IP socket.
///
/// The filter is a 256-bit bitmap stored as 32 bytes. ICMP message type `t`
/// is governed by bit `t % 8` of byte `t / 8`. A set bit blocks that message
/// type and a clear bit lets it through. An all-zero bitmap therefore allows
/// everything, and an all-ones bitmap denies everything.
///
/// NOTE(review): this layout (set bit = filtered out) presumably mirrors the
/// Linux `ICMP6_FILTER` socket option on little-endian hosts. Confirm this
/// before relying on byte-level compatibility.
#[derive(Debug, GenericOverIp, PartialEq, Clone)]
#[generic_over_ip(I, Ip)]
pub struct RawIpSocketIcmpFilter<I: IcmpIpExt> {
    /// Marker carrying the IP version `I` this filter applies to.
    _marker: IpVersionMarker<I>,
    /// The raw bitmap. See the type-level docs for the bit layout.
    filter: [u8; 32],
}
impl<I: IcmpIpExt> RawIpSocketIcmpFilter<I> {
    /// A filter with every bit clear, letting every ICMP message type through.
    pub const ALLOW_ALL: Self = Self::new([0; 32]);
    /// A filter with every bit set, blocking every ICMP message type.
    pub const DENY_ALL: Self = Self::new([u8::MAX; 32]);

    /// Wraps a raw 32-byte bitmap as a filter, without any validation.
    ///
    /// Message type `t` is controlled by bit `t % 8` of byte `t / 8`. A set
    /// bit blocks the type and a clear bit allows it.
    pub const fn new(filter: [u8; 32]) -> RawIpSocketIcmpFilter<I> {
        Self { filter, _marker: IpVersionMarker::new() }
    }

    /// Consumes the filter and hands back its raw 32-byte bitmap.
    pub fn into_bytes(self) -> [u8; 32] {
        self.filter
    }

    /// Returns `true` when `message_type` passes this filter, i.e. when its
    /// bit in the bitmap is clear.
    pub(super) fn allows_type(&self, message_type: I::IcmpMessageType) -> bool {
        let raw: u8 = message_type.into();
        let byte = self.filter[usize::from(raw / 8)];
        // Shift the type's bit down to position 0 and test it.
        ((byte >> (raw % 8)) & 1) == 0
    }

    /// Test-only helper that marks `message_type` as allowed (clears its bit)
    /// or denied (sets its bit).
    #[cfg(test)]
    fn set_type(&mut self, message_type: u8, allow: bool) {
        let mask = 1u8 << (message_type % 8);
        let byte = &mut self.filter[usize::from(message_type / 8)];
        if allow {
            *byte &= !mask;
        } else {
            *byte |= mask;
        }
    }
}
impl<I: IcmpIpExt> Display for RawIpSocketIcmpFilter<I> {
    /// Lists the ICMP message types this filter lets through.
    ///
    /// The output has the shape `AllowedIcmpMessageTypes ["Echo Reply", "Echo Request"]`.
    /// For IPv6 the prefix is `AllowedIcmpv6MessageTypes` instead. Every `u8`
    /// value that converts into an `I::IcmpMessageType` is checked. Allowed
    /// types therefore appear in ascending numeric order, each quoted and
    /// rendered through its `Debug` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let prefix = match I::VERSION {
            IpVersion::V4 => "AllowedIcmpMessageTypes [",
            IpVersion::V6 => "AllowedIcmpv6MessageTypes [",
        };
        f.write_str(prefix)?;
        // Empty before the first listed type, a comma afterwards.
        let mut separator = "";
        for raw in 0..=u8::MAX {
            match I::IcmpMessageType::try_from(raw) {
                Ok(message_type) if self.allows_type(message_type) => {
                    write!(f, "{separator}\"{message_type:?}\"")?;
                    separator = ", ";
                }
                // Unknown type numbers and filtered-out types are skipped.
                _ => {}
            }
        }
        f.write_str("]")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::{String, ToString as _};
    use alloc::vec::Vec;
    use alloc::{format, vec};
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv6};
    use packet_formats::icmp::{Icmpv4MessageType, Icmpv6MessageType};
    use test_case::test_case;

    /// Builds a filter that differs from a blanket policy in exactly one type.
    ///
    /// When `allow` is true, everything except `message_type` is denied.
    /// Otherwise, everything except `message_type` is allowed.
    fn build_precise_filter<I: IcmpIpExt>(
        message_type: u8,
        allow: bool,
    ) -> RawIpSocketIcmpFilter<I> {
        let mut filter = if allow {
            RawIpSocketIcmpFilter::<I>::DENY_ALL
        } else {
            RawIpSocketIcmpFilter::<I>::ALLOW_ALL
        };
        filter.set_type(message_type, allow);
        filter
    }

    // Type 21 lives in byte 2 (21 / 8) at bit 5 (21 % 8), i.e. mask 0x20.
    #[test_case(21, true, [
        0xFF, 0xFF, 0xDF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    ]; "allow_21")]
    #[test_case(21, false, [
        0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ]; "deny_21")]
    fn build_filter(message_type: u8, allow: bool, expected: [u8; 32]) {
        // The byte layout does not depend on the IP version.
        let v4_bytes = build_precise_filter::<Ipv4>(message_type, allow).into_bytes();
        let v6_bytes = build_precise_filter::<Ipv6>(message_type, allow).into_bytes();
        assert_eq!(v4_bytes, expected);
        assert_eq!(v6_bytes, expected);
    }

    /// Every recognized message type must be honored in both directions.
    #[ip_test(I)]
    fn icmp_filter_allows_type<I: IcmpIpExt>() {
        let known_types = (0..=u8::MAX)
            .filter_map(|raw| I::IcmpMessageType::try_from(raw).ok().map(|ty| (raw, ty)));
        for (raw, message_type) in known_types {
            let pass_filter = build_precise_filter::<I>(raw, true);
            let deny_filter = build_precise_filter::<I>(raw, false);
            assert!(pass_filter.allows_type(message_type), "Should allow MessageType:{raw}");
            assert!(!deny_filter.allows_type(message_type), "Should deny MessageType:{raw}");
        }
    }

    #[test_case(vec![] => "AllowedIcmpMessageTypes []".to_string(); "deny_all")]
    #[test_case(vec![Icmpv4MessageType::EchoRequest] =>
        "AllowedIcmpMessageTypes [\"Echo Request\"]".to_string(); "allow_echo_request")]
    #[test_case(vec![Icmpv4MessageType::EchoReply, Icmpv4MessageType::EchoRequest] =>
        "AllowedIcmpMessageTypes [\"Echo Reply\", \"Echo Request\"]".to_string();
        "allow_echo_request_and_reply")]
    fn icmpv4_filter_display(allowed_types: Vec<Icmpv4MessageType>) -> String {
        let filter = allowed_types.into_iter().fold(
            RawIpSocketIcmpFilter::<Ipv4>::DENY_ALL,
            |mut acc, allowed_type| {
                acc.set_type(allowed_type.into(), true);
                acc
            },
        );
        format!("{filter}")
    }

    #[test_case(vec![] => "AllowedIcmpv6MessageTypes []".to_string(); "deny_all")]
    #[test_case(vec![Icmpv6MessageType::EchoRequest] =>
        "AllowedIcmpv6MessageTypes [\"Echo Request\"]".to_string(); "allow_echo_request")]
    #[test_case(vec![Icmpv6MessageType::EchoRequest, Icmpv6MessageType::EchoReply] =>
        "AllowedIcmpv6MessageTypes [\"Echo Request\", \"Echo Reply\"]".to_string();
        "allow_echo_request_and_reply")]
    fn icmpv6_filter_display(allowed_types: Vec<Icmpv6MessageType>) -> String {
        let filter = allowed_types.into_iter().fold(
            RawIpSocketIcmpFilter::<Ipv6>::DENY_ALL,
            |mut acc, allowed_type| {
                acc.set_type(allowed_type.into(), true);
                acc
            },
        );
        format!("{filter}")
    }
}