use core::borrow::Borrow;
use core::fmt::Debug;
use alloc::vec::Vec;
use explicit::UnreachableExt;
use net_types::ip::{GenericOverIp, Ip, IpInvariant, Ipv4, Ipv6, Mtu};
use netstack3_base::{Counter, RngContext, Uninstantiable};
use netstack3_filter::ForwardedPacket;
use packet::{
Buf, BufferMut, EmptyBuf, FragmentedBuffer as _, InnerPacketBuilder as _, Nested,
PacketBuilder, PacketConstraints, ParsablePacket, SerializeError, Serializer,
};
use packet_formats::ip::FragmentOffset;
use packet_formats::ipv4::options::Ipv4Option;
use packet_formats::ipv4::{
Ipv4Header as _, Ipv4PacketBuilder, Ipv4PacketBuilderWithOptions, Ipv4PacketRaw,
};
use packet_formats::ipv6::{
Ipv6PacketBuilder, Ipv6PacketBuilderBeforeFragment, Ipv6PacketBuilderWithFragmentHeader,
};
use rand::Rng;
const MAX_FRAGMENT_OFFSET: usize = ((1 << 13) - 1) * 8;
pub trait FragmentationIpExt:
packet_formats::ip::IpExt<PacketBuilder: AsFragmentableIpPacketBuilder<Self>>
{
type ForwardedFragmentBuilder: FragmentableIpPacketBuilder<Self>;
type FragmentationId: Copy + Debug;
}
impl FragmentationIpExt for Ipv4 {
type ForwardedFragmentBuilder = ForwardedIpv4PacketBuilder;
type FragmentationId = ();
}
impl FragmentationIpExt for Ipv6 {
type ForwardedFragmentBuilder = Uninstantiable;
type FragmentationId = u32;
}
#[derive(Debug, Eq, PartialEq, GenericOverIp)]
#[generic_over_ip()]
pub enum FragmentationError {
NotAllowed,
MtuTooSmall,
BodyTooLong,
SizeLimitExceeded,
}
pub trait FragmentableIpSerializer<I: FragmentationIpExt>: Serializer {
type Builder<'a>: FragmentableIpPacketBuilder<I>
where
Self: 'a;
type Body<'a>: AsRef<[u8]>
where
Self: 'a;
fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError>;
}
impl<I, S, B> FragmentableIpSerializer<I> for Nested<S, B>
where
I: FragmentationIpExt,
S: Serializer,
B: AsFragmentableIpPacketBuilder<I> + PacketBuilder,
{
type Builder<'a>
= B::Builder<'a>
where
Self: 'a;
type Body<'a>
= Buf<Vec<u8>>
where
Self: 'a;
fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
let builder = self.outer().try_as_fragmentable()?;
let body = self
.inner()
.serialize_new_buf(PacketConstraints::UNCONSTRAINED, packet::new_buf_vec)
.map_err(|e| match e {
SerializeError::SizeLimitExceeded => FragmentationError::SizeLimitExceeded,
})?;
Ok((builder, body))
}
}
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum FragmentPosition {
First,
Middle,
Last,
}
pub struct HeaderSizes {
first: usize,
remaining: usize,
}
pub trait AsFragmentableIpPacketBuilder<I: FragmentationIpExt> {
type Builder<'a>: FragmentableIpPacketBuilder<I>
where
Self: 'a;
fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError>;
}
pub trait FragmentableIpPacketBuilder<I: FragmentationIpExt> {
fn header_sizes(&self) -> HeaderSizes;
fn builder_at(
&self,
offset: FragmentOffset,
position: FragmentPosition,
identifier: I::FragmentationId,
) -> impl PacketBuilder + '_;
}
impl<B> AsFragmentableIpPacketBuilder<Ipv4> for B
where
B: InnerIpv4FragmentBuilder,
{
type Builder<'a>
= Ipv4FragmentBuilder<'a, Self>
where
Self: 'a;
fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
can_fragment_ipv4(self.prefix())?;
Ok(Ipv4FragmentBuilder { builder: self })
}
}
trait InnerIpv4FragmentBuilder: PacketBuilder {
fn prefix(&self) -> &Ipv4PacketBuilder;
fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder;
fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder;
fn header_sizes(&self) -> HeaderSizes;
}
impl InnerIpv4FragmentBuilder for Ipv4PacketBuilder {
fn prefix(&self) -> &Ipv4PacketBuilder {
self
}
fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
self
}
fn clone_for_fragment(&self, _position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
self.clone()
}
fn header_sizes(&self) -> HeaderSizes {
let size = self.constraints().header_len();
HeaderSizes { first: size, remaining: size }
}
}
impl<'a, I> InnerIpv4FragmentBuilder for Ipv4PacketBuilderWithOptions<'a, I>
where
I: Iterator<Item: Borrow<Ipv4Option<'a>>> + Clone,
{
fn prefix(&self) -> &Ipv4PacketBuilder {
self.prefix_builder()
}
fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
self.prefix_builder_mut()
}
fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
self.clone().with_fragment_options(position == FragmentPosition::First)
}
fn header_sizes(&self) -> HeaderSizes {
let first = self.constraints().header_len();
let remaining = self.clone().with_fragment_options(false).constraints().header_len();
HeaderSizes { first, remaining }
}
}
pub struct Ipv4FragmentBuilder<'a, B> {
builder: &'a B,
}
impl<'a, B> FragmentableIpPacketBuilder<Ipv4> for Ipv4FragmentBuilder<'a, B>
where
B: InnerIpv4FragmentBuilder,
{
fn header_sizes(&self) -> HeaderSizes {
self.builder.header_sizes()
}
fn builder_at(
&self,
offset: FragmentOffset,
position: FragmentPosition,
(): (),
) -> impl PacketBuilder + '_ {
let mut builder = self.builder.clone_for_fragment(position);
set_ipv4_fragment(builder.prefix_mut(), offset, position);
builder
}
}
impl<B> AsFragmentableIpPacketBuilder<Ipv6> for B
where
for<'a> &'a B: Ipv6PacketBuilderBeforeFragment,
{
type Builder<'a>
= Ipv6FragmentBuilder<'a, Self>
where
Self: 'a;
fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
Ok(Ipv6FragmentBuilder { builder: self })
}
}
pub struct Ipv6FragmentBuilder<'a, B> {
builder: &'a B,
}
impl<'a, B> FragmentableIpPacketBuilder<Ipv6> for Ipv6FragmentBuilder<'a, B>
where
&'a B: Ipv6PacketBuilderBeforeFragment,
{
fn header_sizes(&self) -> HeaderSizes {
let header_len =
Ipv6PacketBuilderWithFragmentHeader::new(self.builder, FragmentOffset::ZERO, false, 0)
.constraints()
.header_len();
HeaderSizes { first: header_len, remaining: header_len }
}
fn builder_at(
&self,
offset: FragmentOffset,
position: FragmentPosition,
identifier: u32,
) -> impl PacketBuilder + '_ {
Ipv6PacketBuilderWithFragmentHeader::new(
self.builder,
offset,
position != FragmentPosition::Last,
identifier,
)
}
}
impl<I, B> FragmentableIpSerializer<I> for ForwardedPacket<I, B>
where
I: FragmentationIpExt,
B: BufferMut,
{
type Builder<'a>
= I::ForwardedFragmentBuilder
where
Self: 'a;
type Body<'a>
= Buf<&'a [u8]>
where
Self: 'a;
fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
#[derive(GenericOverIp)]
#[generic_over_ip(I, Ip)]
struct Out<I: FragmentationIpExt>(I::ForwardedFragmentBuilder);
I::map_ip::<_, Result<(Out<I>, IpInvariant<Buf<&[u8]>>), FragmentationError>>(
self,
|forwarded| {
let mut buffer = forwarded.buffer().as_ref();
let packet = Ipv4PacketRaw::parse(&mut buffer, ())
.expect("ForwardedPacket must be parseable");
let builder = packet.builder();
can_fragment_ipv4(&builder)?;
let raw_options_bytes = packet
.options()
.as_ref()
.complete()
.expect("unexpected incomplete IP header")
.bytes();
let mut raw_options = Buf::new(
[0u8; packet_formats::ipv4::MAX_OPTIONS_LEN],
..raw_options_bytes.len(),
);
raw_options.as_mut().copy_from_slice(raw_options_bytes);
let body = Buf::new(
packet.into_body().complete().expect("unexpected incomplete IP body"),
..,
);
Ok((Out(ForwardedIpv4PacketBuilder { builder, raw_options }), IpInvariant(body)))
},
|_forwarded| Err(FragmentationError::NotAllowed),
)
.map(|(Out(builder), IpInvariant(body))| (builder, body))
}
}
pub struct ForwardedIpv4PacketBuilder {
builder: Ipv4PacketBuilder,
raw_options: Buf<[u8; packet_formats::ipv4::MAX_OPTIONS_LEN]>,
}
impl FragmentableIpPacketBuilder<Ipv4> for ForwardedIpv4PacketBuilder {
fn header_sizes(&self) -> HeaderSizes {
let Self { builder, raw_options } = self;
if raw_options.is_empty() {
builder.header_sizes()
} else {
let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
.expect("must hold valid options");
Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.iter())
.header_sizes()
}
}
fn builder_at(
&self,
offset: FragmentOffset,
position: FragmentPosition,
(): (),
) -> impl PacketBuilder + '_ {
let Self { builder, raw_options } = self;
let mut builder = builder.clone();
set_ipv4_fragment(&mut builder, offset, position);
let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
.expect("must hold valid options");
Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.into_iter())
.with_fragment_options(position == FragmentPosition::First)
}
}
impl<I: FragmentationIpExt> FragmentableIpPacketBuilder<I> for Uninstantiable {
fn header_sizes(&self) -> HeaderSizes {
self.uninstantiable_unreachable()
}
fn builder_at(
&self,
_offset: FragmentOffset,
_position: FragmentPosition,
_identifier: I::FragmentationId,
) -> impl PacketBuilder + '_ {
self.uninstantiable_unreachable::<Ipv6PacketBuilder>()
}
}
pub(crate) trait FragmentationIdGenContext {
fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId;
}
#[derive(GenericOverIp)]
#[generic_over_ip(I, Ip)]
struct WrapFragmentationId<I: FragmentationIpExt>(I::FragmentationId);
impl<BC> FragmentationIdGenContext for BC
where
BC: RngContext,
{
fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
let WrapFragmentationId(identifier) = I::map_ip_out(
self,
|_| WrapFragmentationId(()),
|rng| {
WrapFragmentationId(rng.rng().gen_range(1..=u32::MAX))
},
);
identifier
}
}
pub(crate) struct IpFragmenter<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I> + 'a> {
builder: S::Builder<'a>,
body: S::Body<'a>,
consumed: usize,
max_fragment_body_first: usize,
max_fragment_body_remaining: usize,
identifier: I::FragmentationId,
}
pub trait Capture<'a, 'b> {}
impl<'a, 'b, O> Capture<'a, 'b> for O
where
O: 'b,
'a: 'b,
{
}
fn maximum_fragment_body_with_header_and_mtu(
mtu: Mtu,
header: usize,
) -> Result<usize, FragmentationError> {
let v = usize::from(mtu).checked_sub(header).ok_or(FragmentationError::MtuTooSmall)?;
let v = v & !0x07usize;
if v == 0 {
return Err(FragmentationError::MtuTooSmall);
}
Ok(v)
}
impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
pub(crate) fn new<C: FragmentationIdGenContext>(
id_ctx: &mut C,
serializer: &'a S,
mtu: Mtu,
) -> Result<Self, FragmentationError> {
let (builder, body) = serializer.builder_and_body()?;
let HeaderSizes { first, remaining } = builder.header_sizes();
let max_fragment_body_first = maximum_fragment_body_with_header_and_mtu(mtu, first)?;
let max_fragment_body_remaining =
maximum_fragment_body_with_header_and_mtu(mtu, remaining)?;
if body.as_ref().len() > MAX_FRAGMENT_OFFSET + max_fragment_body_remaining {
return Err(FragmentationError::BodyTooLong);
}
let identifier = id_ctx.generate_id::<I>();
Ok(Self {
builder,
body,
consumed: 0,
max_fragment_body_first,
max_fragment_body_remaining,
identifier,
})
}
pub(crate) fn next(&mut self) -> Option<impl Serializer<Buffer = EmptyBuf> + Capture<'a, '_>> {
let Self {
builder,
body,
consumed,
max_fragment_body_first,
max_fragment_body_remaining,
identifier,
} = self;
let body = &AsRef::as_ref(body)[*consumed..];
if body.is_empty() {
return None;
}
let first = *consumed == 0;
let max_fragment_body =
if first { max_fragment_body_first } else { max_fragment_body_remaining };
let take = body.len().min(*max_fragment_body);
let last = take == body.len();
let position = match (first, last) {
(true, true) => {
panic!("unnecessary fragmentation");
}
(true, false) => FragmentPosition::First,
(false, false) => FragmentPosition::Middle,
(false, true) => FragmentPosition::Last,
};
let fragment_offset = u16::try_from(*consumed).expect("fragment offset too large");
let fragment_offset =
FragmentOffset::new_with_bytes(fragment_offset).expect("invalid offset");
let fragment_builder = builder.builder_at(fragment_offset, position, *identifier);
let end = *consumed + take;
let fragment_body = &body[..take];
*consumed = end;
Some(fragment_body.into_serializer().encapsulate(fragment_builder))
}
}
fn can_fragment_ipv4(builder: &Ipv4PacketBuilder) -> Result<(), FragmentationError> {
if builder.read_df_flag() {
return Err(FragmentationError::NotAllowed);
}
Ok(())
}
fn set_ipv4_fragment(
builder: &mut Ipv4PacketBuilder,
offset: FragmentOffset,
position: FragmentPosition,
) {
builder.mf_flag(position != FragmentPosition::Last);
builder.fragment_offset(offset);
}
#[derive(Default)]
pub struct FragmentationCounters {
pub fragmentation_required: Counter,
pub fragments: Counter,
pub error_not_allowed: Counter,
pub error_mtu_too_small: Counter,
pub error_body_too_long: Counter,
pub error_inner_size_limit_exceeded: Counter,
pub error_fragmented_serializer: Counter,
}
impl FragmentationCounters {
pub(crate) fn error_counter(&self, error: FragmentationError) -> &Counter {
match error {
FragmentationError::NotAllowed => &self.error_not_allowed,
FragmentationError::MtuTooSmall => &self.error_mtu_too_small,
FragmentationError::BodyTooLong => &self.error_body_too_long,
FragmentationError::SizeLimitExceeded => &self.error_inner_size_limit_exceeded,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use net_types::Witness as _;
use netstack3_base::testutil::{TEST_ADDRS_V4, TEST_ADDRS_V6};
use packet::{Buffer, BufferView, GrowBuffer};
use packet_formats::ip::IpProto;
use packet_formats::ipv4::Ipv4Packet;
use packet_formats::ipv6::ext_hdrs::Ipv6ExtensionHeaderData;
use packet_formats::ipv6::{Ipv6Header, Ipv6Packet};
use test_case::test_case;
const TEST_MTU: Mtu = Ipv6::MINIMUM_LINK_MTU;
fn gen_body(len: usize) -> Vec<u8> {
(0u8..=251).cycle().take(len).collect::<Vec<u8>>()
}
impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
fn next_serialized(&mut self) -> Buf<Vec<u8>> {
self.next()
.expect("no more fragments")
.serialize_vec_outer()
.map_err(|(err, _serializer)| err)
.unwrap()
.unwrap_b()
}
}
trait FragmentationTestEnv<I: FragmentationIpExt> {
fn new_serializer<'a>(
&self,
body: &'a [u8],
) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a;
fn check_fragment(
&self,
fragment: &mut Buf<Vec<u8>>,
position: FragmentPosition,
offset: usize,
);
}
#[derive(Default)]
struct Ipv4TestEnv {
dont_frag: bool,
}
impl Ipv4TestEnv {
const fn dont_frag() -> Self {
Self { dont_frag: true }
}
}
const IPV4_ID: u16 = 0x1234;
fn new_ipv4_packet_builder(dont_frag: bool) -> Ipv4PacketBuilder {
let mut builder = Ipv4PacketBuilder::new(
TEST_ADDRS_V4.local_ip,
TEST_ADDRS_V4.remote_ip,
1,
IpProto::Udp.into(),
);
builder.id(IPV4_ID);
builder.df_flag(dont_frag);
builder
}
fn parse_and_check_ipv4_packet(
fragment: &mut Buf<Vec<u8>>,
position: FragmentPosition,
offset: usize,
) -> Ipv4Packet<&[u8]> {
let packet = Ipv4Packet::parse(fragment.buffer_view(), ()).expect("parse fragment");
assert_eq!(packet.src_ip(), TEST_ADDRS_V4.local_ip.get());
assert_eq!(packet.dst_ip(), TEST_ADDRS_V4.remote_ip.get());
assert_eq!(packet.ttl(), 1);
assert_eq!(packet.id(), IPV4_ID);
assert_eq!(packet.proto(), IpProto::Udp.into());
assert_eq!(packet.mf_flag(), position != FragmentPosition::Last);
assert_eq!(usize::from(packet.fragment_offset().into_bytes()), offset);
packet
}
impl FragmentationTestEnv<Ipv4> for Ipv4TestEnv {
fn new_serializer<'a>(
&self,
body: &'a [u8],
) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
let Self { dont_frag } = self;
body.into_serializer().encapsulate(new_ipv4_packet_builder(*dont_frag))
}
fn check_fragment(
&self,
fragment: &mut Buf<Vec<u8>>,
position: FragmentPosition,
offset: usize,
) {
let _ = parse_and_check_ipv4_packet(fragment, position, offset);
}
}
#[derive(Default)]
struct Ipv4WithOptionsTestEnv(Ipv4TestEnv);
const FAKE_OPTION_COPIED_KIND: u8 = 255;
const FAKE_OPTION_COPIED: [u8; 1] = [255];
const FAKE_OPTION_NOT_COPIED_KIND: u8 = 127;
const FAKE_OPTION_NOT_COPIED: [u8; 1] = [127];
impl FragmentationTestEnv<Ipv4> for Ipv4WithOptionsTestEnv {
fn new_serializer<'a>(
&self,
body: &'a [u8],
) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
let Self(Ipv4TestEnv { dont_frag }) = self;
body.into_serializer().encapsulate(
Ipv4PacketBuilderWithOptions::new(
new_ipv4_packet_builder(*dont_frag),
[
Ipv4Option::Unrecognized {
kind: FAKE_OPTION_COPIED_KIND,
data: &FAKE_OPTION_COPIED[..],
},
Ipv4Option::Unrecognized {
kind: FAKE_OPTION_NOT_COPIED_KIND,
data: &FAKE_OPTION_NOT_COPIED[..],
},
],
)
.unwrap(),
)
}
fn check_fragment(
&self,
fragment: &mut Buf<Vec<u8>>,
position: FragmentPosition,
offset: usize,
) {
let packet = parse_and_check_ipv4_packet(fragment, position, offset);
let (copied, not_copied) = packet.iter_options().fold(
(false, false),
|(mut copied, mut not_copied), option| {
let (kind, data) = assert_matches!(option,
Ipv4Option::Unrecognized{ kind, data } => (kind, data)
);
assert_eq!(data.len(), 1);
assert_eq!(data[0], kind);
let seen = match kind {
FAKE_OPTION_COPIED_KIND => &mut copied,
FAKE_OPTION_NOT_COPIED_KIND => &mut not_copied,
k => panic!("unexpected option {k}"),
};
assert_eq!(core::mem::replace(seen, true), false);
(copied, not_copied)
},
);
assert_eq!(copied, true, "must be copied on all fragments {position:?}");
assert_eq!(
not_copied,
position == FragmentPosition::First,
"must only be in first fragment {position:?}"
);
}
}
struct ForwardingTestEnv<E>(E);
impl<I: FragmentationIpExt, E: FragmentationTestEnv<I>> FragmentationTestEnv<I>
for ForwardingTestEnv<E>
{
fn new_serializer<'a>(
&self,
body: &'a [u8],
) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a {
use packet_formats::ip::IpPacket as _;
let Self(inner) = self;
let mut buffer = inner
.new_serializer(body)
.serialize_outer(packet::NoReuseBufferProvider(packet::new_buf_vec))
.map_err(|(err, _)| err)
.unwrap();
let packet =
<I::Packet<_> as ParsablePacket<_, _>>::parse(buffer.buffer_view(), ()).unwrap();
let src_addr = packet.src_ip();
let dst_addr = packet.dst_ip();
let proto = packet.proto();
let meta = packet.parse_metadata();
drop(packet);
ForwardedPacket::new(src_addr, dst_addr, proto, meta, buffer)
}
fn check_fragment(
&self,
fragment: &mut Buf<Vec<u8>>,
position: FragmentPosition,
offset: usize,
) {
let Self(inner) = self;
inner.check_fragment(fragment, position, offset)
}
}
struct Ipv6TestEnv;
const IPV6_ID: u32 = 0x1234ABCD;
impl FragmentationTestEnv<Ipv6> for Ipv6TestEnv {
fn new_serializer<'a>(
&self,
body: &'a [u8],
) -> impl FragmentableIpSerializer<Ipv6, Buffer: Buffer> + 'a {
body.into_serializer().encapsulate(Ipv6PacketBuilder::new(
TEST_ADDRS_V6.local_ip,
TEST_ADDRS_V6.remote_ip,
1,
IpProto::Udp.into(),
))
}
fn check_fragment(
&self,
fragment: &mut Buf<Vec<u8>>,
position: FragmentPosition,
offset: usize,
) {
let packet = Ipv6Packet::parse(fragment.buffer_view(), ()).unwrap();
assert_eq!(packet.src_ip(), TEST_ADDRS_V6.local_ip.get());
assert_eq!(packet.dst_ip(), TEST_ADDRS_V6.remote_ip.get());
assert_eq!(packet.hop_limit(), 1);
assert_eq!(packet.proto(), IpProto::Udp.into());
let fragment = packet
.iter_extension_hdrs()
.find_map(|h| match h.into_data() {
Ipv6ExtensionHeaderData::Fragment { fragment_data } => Some(fragment_data),
_ => None,
})
.expect("no fragment header");
assert_eq!(fragment.identification(), IPV6_ID);
assert_eq!(usize::from(fragment.fragment_offset().into_bytes()), offset);
assert_eq!(fragment.m_flag(), position != FragmentPosition::Last);
}
}
struct FixedIdContext;
impl FragmentationIdGenContext for FixedIdContext {
fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
let WrapFragmentationId(id) =
I::map_ip_out((), |()| WrapFragmentationId(()), |()| WrapFragmentationId(IPV6_ID));
id
}
}
#[test_case::test_matrix(
[
Ipv4TestEnv::default(),
Ipv4WithOptionsTestEnv::default(),
ForwardingTestEnv(Ipv4TestEnv::default()),
ForwardingTestEnv(Ipv4WithOptionsTestEnv::default()),
Ipv6TestEnv,
],
0..=2
)]
fn fragment<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(
env: E,
middle_fragments: usize,
) {
let full_body = gen_body(usize::from(TEST_MTU) * (1 + middle_fragments));
let mut body_view = Buf::new(&full_body[..], ..);
let serializer = env.new_serializer(&full_body[..]);
let mut fragmenter = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU)
.expect("create fragmenter");
let mut frag = fragmenter.next_serialized();
env.check_fragment(&mut frag, FragmentPosition::First, body_view.prefix_len());
assert_eq!(
frag.as_ref(),
body_view.buffer_view().take_front(fragmenter.max_fragment_body_first).unwrap()
);
for _ in 0..middle_fragments {
let mut frag = fragmenter.next_serialized();
env.check_fragment(&mut frag, FragmentPosition::Middle, body_view.prefix_len());
assert_eq!(
frag.as_ref(),
body_view.buffer_view().take_front(fragmenter.max_fragment_body_remaining).unwrap()
);
}
let mut frag = fragmenter.next_serialized();
env.check_fragment(&mut frag, FragmentPosition::Last, body_view.prefix_len());
assert_eq!(frag.as_ref(), body_view.buffer_view().into_rest());
assert!(fragmenter.next().is_none());
}
#[test_case(Ipv4TestEnv::dont_frag())]
#[test_case(Ipv4WithOptionsTestEnv(Ipv4TestEnv::dont_frag()))]
#[test_case(ForwardingTestEnv(Ipv4TestEnv::dont_frag()))]
#[test_case(ForwardingTestEnv(Ipv6TestEnv))]
fn not_allowed<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
let body = gen_body(usize::from(TEST_MTU));
let serializer = env.new_serializer(&body[..]);
let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
assert_eq!(result, Err(FragmentationError::NotAllowed))
}
#[test_case(Ipv4TestEnv::default())]
#[test_case(Ipv4WithOptionsTestEnv::default())]
#[test_case(ForwardingTestEnv(Ipv4TestEnv::default()))]
#[test_case(Ipv6TestEnv)]
fn mtu_too_small<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
let body = gen_body(usize::from(TEST_MTU));
let serializer = env.new_serializer(&body[..]);
let result = IpFragmenter::new(&mut FixedIdContext, &serializer, Mtu::new(10)).map(|_| ());
assert_eq!(result, Err(FragmentationError::MtuTooSmall));
}
#[test_case(Ipv4TestEnv::default())]
#[test_case(Ipv4WithOptionsTestEnv::default())]
#[test_case(Ipv6TestEnv)]
fn body_too_long<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
let body = gen_body(MAX_FRAGMENT_OFFSET + usize::from(TEST_MTU));
let serializer = env.new_serializer(&body[..]);
let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
assert_eq!(result, Err(FragmentationError::BodyTooLong));
}
}