use core::ops::Deref;
use net_types::ip::Ipv4Addr;
use packet::records::{ParsedRecord, RecordParseResult, Records, RecordsImpl, RecordsImplLayout};
use packet::{BufferView, ParsablePacket, ParseMetadata};
use zerocopy::byteorder::network_endian::U16;
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, SplitByteSlice, Unaligned};
use super::{
parse_v3_possible_floating_point, peek_message_type, IgmpMessage, IgmpNonEmptyBody,
IgmpResponseTimeV2, IgmpResponseTimeV3,
};
use crate::error::{ParseError, UnrecognizedProtocolCode};
use crate::igmp::MessageType;
// IGMP message type codes (the value of an IGMP message's first byte).
// 0x11 (MembershipQuery) is shared by the IGMPv2 and IGMPv3 query types in this
// file; `IgmpPacket::parse` tells them apart using the boolean returned by
// `peek_message_type`.
create_protocol_enum!(
#[allow(missing_docs)]
#[derive(PartialEq, Copy, Clone)]
pub enum IgmpMessageType: u8 {
MembershipQuery, 0x11, "Membership Query";
MembershipReportV1,0x12, "Membership Report V1";
MembershipReportV2,0x16, "Membership Report V2";
MembershipReportV3,0x22, "Membership Report V3";
LeaveGroup, 0x17, "Leave Group";
}
);
/// Implements `MessageType<B>` for a message type that has a fixed header, no
/// Max Resp Time (`MaxRespTime = ()`), and no variable body.
///
/// - `$type`: the marker type to implement `MessageType` for.
/// - `$code`: the `IgmpMessageType` variant used as `TYPE`.
/// - `$fixed_header`: the fixed header type (e.g. `Ipv4Addr`).
///
/// The body-related items come from `declare_no_body!`.
macro_rules! impl_igmp_simple_message_type {
($type:ident, $code:tt, $fixed_header:ident) => {
impl<B> MessageType<B> for $type {
type FixedHeader = $fixed_header;
const TYPE: IgmpMessageType = IgmpMessageType::$code;
type MaxRespTime = ();
declare_no_body!();
}
};
}
/// Expands to the body-related `MessageType` items for a message that carries
/// no variable body.
///
/// - `VariableBody` is `()`.
/// - Parsing fails with `ParseError::NotExpected` if any bytes remain after the
///   fixed header.
/// - Serialization contributes no body bytes.
///
/// Must be invoked inside an `impl<B> MessageType<B> for ...` block, because the
/// expansion refers to that impl's `B` parameter.
macro_rules! declare_no_body {
    () => {
        type VariableBody = ();

        fn parse_body<BV: BufferView<B>>(
            _header: &Self::FixedHeader,
            bytes: BV,
        ) -> Result<Self::VariableBody, ParseError>
        where
            B: SplitByteSlice,
        {
            // Nothing is allowed to follow the fixed header.
            match bytes.len() {
                0 => Ok(()),
                _ => Err(ParseError::NotExpected),
            }
        }

        fn body_bytes(_body: &Self::VariableBody) -> &[u8]
        where
            B: SplitByteSlice,
        {
            &[]
        }
    };
}
/// Marker type for an IGMPv2 Membership Query.
///
/// The fixed header is the queried group address. The Max Resp Time field is
/// interpreted as an `IgmpResponseTimeV2`. There is no variable body.
///
/// This type shares the `MembershipQuery` type code with
/// `IgmpMembershipQueryV3`.
#[derive(Copy, Clone, Debug)]
pub struct IgmpMembershipQueryV2;

impl<B> MessageType<B> for IgmpMembershipQueryV2 {
    const TYPE: IgmpMessageType = IgmpMessageType::MembershipQuery;

    type FixedHeader = Ipv4Addr;
    type MaxRespTime = IgmpResponseTimeV2;

    declare_no_body!();
}
/// Fixed header of an IGMPv3 Membership Query, as laid out on the wire.
///
/// Field order is significant: the struct is `repr(C)` and parsed/serialized
/// directly via zerocopy.
#[derive(Copy, Clone, Debug, IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned)]
#[repr(C)]
pub struct MembershipQueryData {
/// The group address being queried.
group_address: Ipv4Addr,
/// Packed flags byte: bit 3 is the S flag (`S_FLAG`), the low three bits are
/// the Querier's Robustness Variable (`QRV_MSK`).
sqrv: u8,
/// Encoded Querier's Query Interval Code; decoded by
/// `querier_query_interval` via `parse_v3_possible_floating_point`.
qqic: u8,
/// Number of source addresses in the query body (network byte order).
number_of_sources: U16,
}
impl MembershipQueryData {
    /// Mask for the S ("Suppress Router-Side Processing") flag in `sqrv`.
    // NOTE(review): both masks are used by the accessors below; the allow is
    // kept only to avoid new warnings in builds where those accessors are
    // unreachable — confirm whether it can be removed.
    #[allow(dead_code)]
    const S_FLAG: u8 = 1 << 3;
    /// Mask for the Querier's Robustness Variable (QRV) bits in `sqrv`.
    #[allow(dead_code)]
    const QRV_MSK: u8 = 0x07;

    /// Default Querier's Robustness Variable value.
    pub(crate) const _DEFAULT_QRV: u8 = 2;

    /// Default Query Interval.
    ///
    /// RFC 3376 section 8.2 specifies a default of 125 seconds.
    pub const DEFAULT_QUERY_INTERVAL: core::time::Duration = core::time::Duration::from_secs(125);

    /// Returns the number of source addresses that follow this header.
    pub fn number_of_sources(self) -> u16 {
        self.number_of_sources.get()
    }

    /// Returns `true` if the S ("Suppress Router-Side Processing") flag is set.
    pub fn suppress_router_side_processing(self) -> bool {
        (self.sqrv & Self::S_FLAG) != 0
    }

    /// Returns the Querier's Robustness Variable (the low three bits of `sqrv`).
    pub fn querier_robustness_variable(self) -> u8 {
        self.sqrv & Self::QRV_MSK
    }

    /// Returns the Querier's Query Interval, decoded from the QQIC field.
    ///
    /// QQIC may use the IGMPv3 floating-point encoding, which is decoded by
    /// `parse_v3_possible_floating_point`. The decoded value is in seconds.
    pub fn querier_query_interval(self) -> core::time::Duration {
        core::time::Duration::from_secs(parse_v3_possible_floating_point(self.qqic).into())
    }
}
/// Marker type for an IGMPv3 Membership Query.
///
/// The fixed header is a [`MembershipQueryData`], and the Max Resp Time field
/// is interpreted as an `IgmpResponseTimeV3`. The variable body is the list of
/// source addresses, whose length is given by the header's `number_of_sources`.
#[derive(Copy, Clone, Debug)]
pub struct IgmpMembershipQueryV3;

impl<B> IgmpNonEmptyBody for Ref<B, [Ipv4Addr]> {}

impl<B> MessageType<B> for IgmpMembershipQueryV3 {
    const TYPE: IgmpMessageType = IgmpMessageType::MembershipQuery;

    type FixedHeader = MembershipQueryData;
    type MaxRespTime = IgmpResponseTimeV3;
    type VariableBody = Ref<B, [Ipv4Addr]>;

    /// Takes exactly `header.number_of_sources()` addresses from the front of
    /// `bytes`.
    ///
    /// Returns [`ParseError::Format`] if fewer addresses are available.
    fn parse_body<BV: BufferView<B>>(
        header: &Self::FixedHeader,
        mut bytes: BV,
    ) -> Result<Self::VariableBody, ParseError>
    where
        B: SplitByteSlice,
    {
        let source_count = usize::from(header.number_of_sources());
        match bytes.take_slice_front::<Ipv4Addr>(source_count) {
            Some(sources) => Ok(sources),
            None => Err(ParseError::Format),
        }
    }

    /// Returns the raw bytes backing the source address list.
    fn body_bytes(body: &Self::VariableBody) -> &[u8]
    where
        B: SplitByteSlice,
    {
        Ref::bytes(body)
    }
}
/// Fixed header of an IGMPv3 Membership Report, as laid out on the wire.
#[derive(Copy, Clone, Debug, IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned)]
#[repr(C)]
pub struct MembershipReportV3Data {
/// Reserved bytes; not interpreted by this module.
_reserved: [u8; 2],
/// Number of group records in the report body (network byte order).
number_of_group_records: U16,
}
impl MembershipReportV3Data {
    /// Returns how many group records follow this header, converted from the
    /// network-byte-order wire field.
    pub fn number_of_group_records(self) -> u16 {
        let Self { number_of_group_records, .. } = self;
        number_of_group_records.get()
    }
}
// Record Type values of an IGMPv3 group record, i.e. the interpretations of the
// raw `record_type` byte of `GroupRecordHeader`. Unknown values are reported by
// `GroupRecordHeader::record_type` as `UnrecognizedProtocolCode`.
create_protocol_enum!(
#[allow(missing_docs)]
#[derive(PartialEq, Copy, Clone)]
pub enum IgmpGroupRecordType: u8 {
ModeIsInclude, 0x01, "Mode Is Include";
ModeIsExclude, 0x02, "Mode Is Exclude";
ChangeToIncludeMode, 0x03, "Change To Include Mode";
ChangeToExcludeMode, 0x04, "Change To Exclude Mode";
AllowNewSources, 0x05, "Allow New Sources";
BlockOldSources, 0x06, "Block Old Sources";
}
);
/// Fixed-size header that starts every group record of an IGMPv3 Membership
/// Report, as laid out on the wire.
#[derive(Copy, Clone, Debug, IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned)]
#[repr(C)]
pub struct GroupRecordHeader {
/// Raw record type; see `record_type()` for the decoded form.
record_type: u8,
/// Length of the auxiliary data following the sources, in 4-byte units.
aux_data_len: u8,
/// Number of source addresses following this header (network byte order).
number_of_sources: U16,
/// The multicast group this record refers to.
multicast_address: Ipv4Addr,
}
impl GroupRecordHeader {
    /// Returns how many source addresses follow this header in the record.
    pub fn number_of_sources(&self) -> u16 {
        let Self { number_of_sources, .. } = self;
        number_of_sources.get()
    }

    /// Decodes the raw record type byte.
    ///
    /// Returns an `UnrecognizedProtocolCode` error for values that are not a
    /// known [`IgmpGroupRecordType`].
    pub fn record_type(&self) -> Result<IgmpGroupRecordType, UnrecognizedProtocolCode<u8>> {
        let raw: u8 = self.record_type;
        IgmpGroupRecordType::try_from(raw)
    }

    /// Borrows the multicast group address this record refers to.
    pub fn multicast_addr(&self) -> &Ipv4Addr {
        let Self { multicast_address, .. } = self;
        multicast_address
    }
}
/// A single parsed IGMPv3 group record, borrowing from the packet buffer `B`.
///
/// Any auxiliary data that followed the sources on the wire is not retained.
pub struct GroupRecord<B> {
/// The record's fixed-size header.
header: Ref<B, GroupRecordHeader>,
/// The `header.number_of_sources()` source addresses.
sources: Ref<B, [Ipv4Addr]>,
}
impl<B: SplitByteSlice> GroupRecord<B> {
    /// Borrows this record's fixed-size header.
    pub fn header(&self) -> &GroupRecordHeader {
        Deref::deref(&self.header)
    }

    /// Borrows this record's source address list.
    pub fn sources(&self) -> &[Ipv4Addr] {
        Deref::deref(&self.sources)
    }
}
/// Marker type for an IGMPv3 Membership Report.
///
/// The fixed header is a [`MembershipReportV3Data`]. The variable body is a
/// sequence of [`GroupRecord`]s. The report has no Max Resp Time.
#[derive(Copy, Clone, Debug)]
pub struct IgmpMembershipReportV3;

impl<B> IgmpNonEmptyBody for Records<B, IgmpMembershipReportV3> {}

impl<B> MessageType<B> for IgmpMembershipReportV3 {
    const TYPE: IgmpMessageType = IgmpMessageType::MembershipReportV3;

    type FixedHeader = MembershipReportV3Data;
    type MaxRespTime = ();
    type VariableBody = Records<B, IgmpMembershipReportV3>;

    /// Parses the remainder of `bytes` as group records.
    ///
    /// The header's record count is passed as the records context.
    fn parse_body<BV: BufferView<B>>(
        header: &Self::FixedHeader,
        bytes: BV,
    ) -> Result<Self::VariableBody, ParseError>
    where
        B: SplitByteSlice,
    {
        let record_count = usize::from(header.number_of_group_records());
        let rest = bytes.into_rest();
        Records::parse_with_context(rest, record_count)
    }

    /// Returns the raw bytes backing all group records.
    fn body_bytes(body: &Self::VariableBody) -> &[u8]
    where
        B: SplitByteSlice,
    {
        body.bytes()
    }
}
// Records layout for IGMPv3 report group records. The `usize` context is the
// header's `number_of_group_records` (supplied by `parse_body`); the per-record
// parser below ignores it.
// NOTE(review): presumably the `packet` records machinery uses this count to
// bound/verify the number of records — confirm.
impl RecordsImplLayout for IgmpMembershipReportV3 {
type Context = usize;
type Error = ParseError;
}
impl RecordsImpl for IgmpMembershipReportV3 {
    type Record<'a> = GroupRecord<&'a [u8]>;

    /// Parses one IGMPv3 group record from the front of `data`.
    ///
    /// A group record has three parts, in order:
    /// 1. a [`GroupRecordHeader`];
    /// 2. `number_of_sources` IPv4 source addresses;
    /// 3. `aux_data_len` 4-byte units of auxiliary data, which are consumed and
    ///    discarded.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::Format`] if `data` is too short for any of the
    /// three parts.
    fn parse_with_context<'a, BV: BufferView<&'a [u8]>>(
        data: &mut BV,
        // NOTE(review): the record count is not checked here; presumably the
        // records machinery enforces it — confirm.
        _ctx: &mut usize,
    ) -> RecordParseResult<GroupRecord<&'a [u8]>, ParseError> {
        let header = data
            .take_obj_front::<GroupRecordHeader>()
            .ok_or_else(debug_err_fn!(ParseError::Format, "Can't take group record header"))?;
        let sources = data
            .take_slice_front::<Ipv4Addr>(header.number_of_sources().into())
            .ok_or_else(debug_err_fn!(ParseError::Format, "Can't take group record sources"))?;
        // `aux_data_len` counts 4-byte units, not bytes.
        let _ = data
            .take_front(usize::from(header.aux_data_len) * 4)
            .ok_or_else(debug_err_fn!(ParseError::Format, "Can't skip auxiliary data"))?;
        Ok(ParsedRecord::Parsed(Self::Record { header, sources }))
    }
}
/// Marker type for an IGMPv1 Membership Report.
///
/// The fixed header is the reported group address, and there is no body.
#[derive(Debug)]
pub struct IgmpMembershipReportV1;
impl_igmp_simple_message_type!(IgmpMembershipReportV1, MembershipReportV1, Ipv4Addr);
/// Marker type for an IGMPv2 Membership Report.
///
/// The fixed header is the reported group address, and there is no body.
#[derive(Debug)]
pub struct IgmpMembershipReportV2;
impl_igmp_simple_message_type!(IgmpMembershipReportV2, MembershipReportV2, Ipv4Addr);
/// Marker type for an IGMP Leave Group message.
///
/// The fixed header is the group address being left, and there is no body.
#[derive(Debug)]
pub struct IgmpLeaveGroup;
impl_igmp_simple_message_type!(IgmpLeaveGroup, LeaveGroup, Ipv4Addr);
/// Any IGMP message understood by this module, with one variant per message type.
///
/// Produced by `IgmpPacket::parse`, which dispatches on the message type code
/// and, for queries, on whether the query is IGMPv3.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum IgmpPacket<B: SplitByteSlice> {
MembershipQueryV2(IgmpMessage<B, IgmpMembershipQueryV2>),
MembershipQueryV3(IgmpMessage<B, IgmpMembershipQueryV3>),
MembershipReportV1(IgmpMessage<B, IgmpMembershipReportV1>),
MembershipReportV2(IgmpMessage<B, IgmpMembershipReportV2>),
MembershipReportV3(IgmpMessage<B, IgmpMembershipReportV3>),
LeaveGroup(IgmpMessage<B, IgmpLeaveGroup>),
}
impl<B: SplitByteSlice> ParsablePacket<B, ()> for IgmpPacket<B> {
type Error = ParseError;
fn parse_metadata(&self) -> ParseMetadata {
use self::IgmpPacket::*;
match self {
MembershipQueryV2(p) => p.parse_metadata(),
MembershipQueryV3(p) => p.parse_metadata(),
MembershipReportV1(p) => p.parse_metadata(),
MembershipReportV2(p) => p.parse_metadata(),
MembershipReportV3(p) => p.parse_metadata(),
LeaveGroup(p) => p.parse_metadata(),
}
}
fn parse<BV: BufferView<B>>(buffer: BV, args: ()) -> Result<Self, ParseError> {
macro_rules! mtch {
($buffer:expr, $args:expr, $( ($code:ident, $long:tt) => $type:ty, $variant:ident )*) => {
match peek_message_type($buffer.as_ref())? {
$( (IgmpMessageType::$code, $long) => {
let packet = <IgmpMessage<B,$type> as ParsablePacket<_, _>>::parse($buffer, $args)?;
IgmpPacket::$variant(packet)
})*,
}
}
}
Ok(mtch!(
buffer,
args,
(MembershipQuery, false) => IgmpMembershipQueryV2, MembershipQueryV2
(MembershipQuery, true) => IgmpMembershipQueryV3, MembershipQueryV3
(MembershipReportV1, _) => IgmpMembershipReportV1, MembershipReportV1
(MembershipReportV2, _) => IgmpMembershipReportV2, MembershipReportV2
(MembershipReportV3, _) => IgmpMembershipReportV3, MembershipReportV3
(LeaveGroup, _) => IgmpLeaveGroup, LeaveGroup
))
}
}
#[cfg(test)]
mod tests {
    use core::fmt::Debug;
    use packet::{InnerPacketBuilder, ParseBuffer, Serializer};

    use super::*;
    use crate::igmp::testdata::*;
    use crate::testutil::set_logger_for_test;

    /// One well-formed sample of every message kind `IgmpPacket` can parse.
    const ALL_BUFFERS: [&[u8]; 6] = [
        igmp_router_queries::v2::QUERY,
        igmp_router_queries::v3::QUERY,
        igmp_reports::v1::MEMBER_REPORT,
        igmp_reports::v2::MEMBER_REPORT,
        igmp_reports::v3::MEMBER_REPORT,
        igmp_leave_group::LEAVE_GROUP,
    ];

    /// Re-serializes a message that has a non-empty variable body.
    fn serialize_to_bytes<B: SplitByteSlice + Debug, M: MessageType<B> + Debug>(
        igmp: &IgmpMessage<B, M>,
    ) -> Vec<u8>
    where
        M::VariableBody: IgmpNonEmptyBody,
    {
        M::body_bytes(&igmp.body)
            .into_serializer()
            .encapsulate(igmp.builder())
            .serialize_vec_outer()
            .unwrap()
            .as_ref()
            .to_vec()
    }

    /// Re-serializes a message that has no variable body.
    fn serialize_to_bytes_inner<
        B: SplitByteSlice + Debug,
        M: MessageType<B, VariableBody = ()> + Debug,
    >(
        igmp: &IgmpMessage<B, M>,
    ) -> Vec<u8> {
        igmp.builder().into_serializer().serialize_vec_outer().unwrap().as_ref().to_vec()
    }

    /// Parses `req` as an `M` (with a variable body), runs `check` on the
    /// result, and asserts that re-serialization reproduces `req` exactly.
    fn test_parse_and_serialize<
        B: SplitByteSlice + Debug,
        BV: BufferView<B>,
        M: MessageType<B> + Debug,
        F: FnOnce(&IgmpMessage<B, M>),
    >(
        req: BV,
        check: F,
    ) where
        M::VariableBody: IgmpNonEmptyBody,
    {
        let orig_req = req.as_ref().to_owned();
        let igmp = IgmpMessage::<_, M>::parse(req, ()).unwrap();
        check(&igmp);
        let data = serialize_to_bytes(&igmp);
        assert_eq!(data, orig_req);
    }

    /// Parses `req` as an `M` (without a variable body), runs `check` on the
    /// result, and asserts that re-serialization reproduces `req` exactly.
    fn test_parse_and_serialize_inner<
        M: for<'a> MessageType<&'a [u8], VariableBody = ()> + Debug,
        F: for<'a> FnOnce(&IgmpMessage<&'a [u8], M>),
    >(
        mut req: &[u8],
        check: F,
    ) {
        let orig_req = req;
        let igmp = req.parse_with::<_, IgmpMessage<_, M>>(()).unwrap();
        check(&igmp);
        let data = serialize_to_bytes_inner(&igmp);
        assert_eq!(&data[..], orig_req);
    }

    #[test]
    fn membership_query_v2_parse_and_serialize() {
        set_logger_for_test();
        test_parse_and_serialize_inner::<IgmpMembershipQueryV2, _>(
            igmp_router_queries::v2::QUERY,
            |igmp| {
                assert_eq!(
                    *igmp.header,
                    Ipv4Addr::new(igmp_router_queries::v2::HOST_GROUP_ADDRESS)
                );
                assert_eq!(igmp.prefix.max_resp_code, igmp_router_queries::v2::MAX_RESP_CODE);
            },
        );
    }

    #[test]
    fn membership_query_v3_parse_and_serialize() {
        set_logger_for_test();
        let mut req = igmp_router_queries::v3::QUERY;
        test_parse_and_serialize::<_, _, IgmpMembershipQueryV3, _>(&mut req, |igmp| {
            assert_eq!(igmp.prefix.max_resp_code, igmp_router_queries::v3::MAX_RESP_CODE);
            assert_eq!(
                igmp.header.group_address,
                Ipv4Addr::new(igmp_router_queries::v3::GROUP_ADDRESS)
            );
            assert_eq!(igmp.header.number_of_sources(), igmp_router_queries::v3::NUMBER_OF_SOURCES);
            assert_eq!(
                igmp.header.suppress_router_side_processing(),
                igmp_router_queries::v3::SUPPRESS_ROUTER_SIDE
            );
            assert_eq!(igmp.header.querier_robustness_variable(), igmp_router_queries::v3::QRV);
            assert_eq!(
                igmp.header.querier_query_interval().as_secs() as u32,
                igmp_router_queries::v3::QQIC_SECS
            );
            assert_eq!(igmp.body.len(), igmp_router_queries::v3::NUMBER_OF_SOURCES as usize);
            assert_eq!(igmp.body[0], Ipv4Addr::new(igmp_router_queries::v3::SOURCE));
        });
    }

    #[test]
    fn membership_report_v3_parse_and_serialize() {
        use igmp_reports::v3::*;
        set_logger_for_test();
        let mut req = MEMBER_REPORT;
        test_parse_and_serialize::<_, _, IgmpMembershipReportV3, _>(&mut req, |igmp| {
            assert_eq!(igmp.header.number_of_group_records(), NUMBER_OF_RECORDS);
            assert_eq!(igmp.prefix.max_resp_code, MAX_RESP_CODE);
            let mut iter = igmp.body.iter();
            let rec1 = iter.next().unwrap();
            assert_eq!(rec1.header().number_of_sources(), NUMBER_OF_SOURCES_1);
            assert_eq!(rec1.header().record_type, RECORD_TYPE_1);
            assert_eq!(rec1.header().multicast_address, Ipv4Addr::new(MULTICAST_ADDR_1));
            assert_eq!(rec1.header().record_type(), Ok(IgmpGroupRecordType::ModeIsInclude));
            assert_eq!(rec1.sources().len(), NUMBER_OF_SOURCES_1 as usize);
            assert_eq!(rec1.sources()[0], Ipv4Addr::new(SRC_1_1));
            assert_eq!(rec1.sources()[1], Ipv4Addr::new(SRC_1_2));
            let rec2 = iter.next().unwrap();
            assert_eq!(rec2.header().number_of_sources(), NUMBER_OF_SOURCES_2);
            assert_eq!(rec2.header().record_type, RECORD_TYPE_2);
            assert_eq!(rec2.header().multicast_address, Ipv4Addr::new(MULTICAST_ADDR_2));
            assert_eq!(rec2.header().record_type(), Ok(IgmpGroupRecordType::ModeIsExclude));
            assert_eq!(rec2.sources().len(), NUMBER_OF_SOURCES_2 as usize);
            assert_eq!(rec2.sources()[0], Ipv4Addr::new(SRC_2_1));
            assert!(iter.next().is_none());
        });
    }

    #[test]
    fn membership_report_v1_parse_and_serialize() {
        use igmp_reports::v1;
        set_logger_for_test();
        test_parse_and_serialize_inner::<IgmpMembershipReportV1, _>(v1::MEMBER_REPORT, |igmp| {
            assert_eq!(*igmp.header, Ipv4Addr::new(v1::GROUP_ADDRESS));
        });
    }

    #[test]
    fn membership_report_v2_parse_and_serialize() {
        use igmp_reports::v2;
        set_logger_for_test();
        test_parse_and_serialize_inner::<IgmpMembershipReportV2, _>(v2::MEMBER_REPORT, |igmp| {
            assert_eq!(*igmp.header, Ipv4Addr::new(v2::GROUP_ADDRESS));
        });
    }

    #[test]
    fn leave_group_parse_and_serialize() {
        set_logger_for_test();
        test_parse_and_serialize_inner::<IgmpLeaveGroup, _>(
            igmp_leave_group::LEAVE_GROUP,
            |igmp| {
                assert_eq!(*igmp.header, Ipv4Addr::new(igmp_leave_group::GROUP_ADDRESS));
            },
        );
    }

    #[test]
    fn test_unknown_type() {
        let mut buff = igmp_invalid_buffers::UNKNOWN_TYPE.to_vec();
        let mut buff = buff.as_mut_slice();
        let packet = buff.parse_with::<_, IgmpPacket<_>>(());
        assert!(packet.is_err());
    }

    #[test]
    fn test_full_parses() {
        let mut bufs = ALL_BUFFERS.to_vec();
        for buff in bufs.iter_mut() {
            let orig_req = &buff[..];
            let packet = buff.parse_with::<_, IgmpPacket<_>>(()).unwrap();
            let msg_type = match packet {
                IgmpPacket::MembershipQueryV2(p) => p.prefix.msg_type,
                IgmpPacket::MembershipQueryV3(p) => p.prefix.msg_type,
                IgmpPacket::MembershipReportV1(p) => p.prefix.msg_type,
                IgmpPacket::MembershipReportV2(p) => p.prefix.msg_type,
                IgmpPacket::MembershipReportV3(p) => p.prefix.msg_type,
                IgmpPacket::LeaveGroup(p) => p.prefix.msg_type,
            };
            assert_eq!(msg_type, orig_req[0]);
        }
    }

    #[test]
    fn test_partial_parses() {
        for buff in ALL_BUFFERS.iter() {
            for i in 0..buff.len() {
                let partial_buff = &mut &buff[0..i];
                let packet = partial_buff.parse_with::<_, IgmpPacket<_>>(());
                assert!(
                    packet.is_err(),
                    "truncated buffer of length {} (full length {}) unexpectedly parsed",
                    i,
                    buff.len()
                );
            }
        }
    }

    /// Asserts that a body-less message's builder reports exactly the length
    /// of the bytes it was parsed from.
    fn assert_message_length<Message: for<'a> MessageType<&'a [u8], VariableBody = ()>>(
        mut ground_truth: &[u8],
    ) {
        let ground_truth_len = ground_truth.len();
        let igmp = ground_truth.parse_with::<_, IgmpMessage<&[u8], Message>>(()).unwrap();
        let builder_len = igmp.builder().bytes_len();
        assert_eq!(builder_len, ground_truth_len);
    }

    #[test]
    fn test_igmp_packet_length() {
        assert_message_length::<IgmpMembershipQueryV2>(igmp_router_queries::v2::QUERY);
        assert_message_length::<IgmpMembershipReportV1>(igmp_reports::v1::MEMBER_REPORT);
        assert_message_length::<IgmpMembershipReportV2>(igmp_reports::v2::MEMBER_REPORT);
        assert_message_length::<IgmpLeaveGroup>(igmp_leave_group::LEAVE_GROUP);
    }
}