Skip to main content

netstack3_ip/
fragmentation.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! IP fragmentation support.
6
7use core::borrow::Borrow;
8use core::fmt::Debug;
9
10use alloc::vec::Vec;
11
12use explicit::UnreachableExt;
13use net_types::ip::{GenericOverIp, Ip, IpInvariant, Ipv4, Ipv6, Mtu};
14use netstack3_base::{
15    Counter, NetworkSerializationContext, NetworkSerializer, RngContext, Uninstantiable,
16};
17use netstack3_filter::ForwardedPacket;
18use packet::{
19    Buf, BufferMut, EmptyBuf, FragmentedBuffer as _, InnerPacketBuilder as _,
20    NestablePacketBuilder, Nested, PacketBuilder, PacketConstraints, ParsablePacket,
21    SerializeError,
22};
23use packet_formats::ip::FragmentOffset;
24use packet_formats::ipv4::options::Ipv4Option;
25use packet_formats::ipv4::{
26    Ipv4Header as _, Ipv4PacketBuilder, Ipv4PacketBuilderWithOptions, Ipv4PacketRaw,
27};
28use packet_formats::ipv6::{
29    Ipv6PacketBuilder, Ipv6PacketBuilderBeforeFragment, Ipv6PacketBuilderWithFragmentHeader,
30};
31use rand::Rng;
32
/// Largest fragment offset, in bytes, that both the IPv4 and the IPv6
/// fragmentation headers can express.
///
/// Both headers carry the offset in a 13-bit field measured in units of 8
/// octets, so the largest representable offset is `0x1FFF * 8 = 65528` bytes.
/// The largest body that can be fragmented is therefore this offset plus the
/// bytes carried by the final fragment.
const MAX_FRAGMENT_OFFSET: usize = 0x1FFF * 8;
38
39pub trait FragmentationIpExt:
40    packet_formats::ip::IpExt<
41        PacketBuilder<NetworkSerializationContext>: AsFragmentableIpPacketBuilder<Self>,
42    >
43{
44    /// The IP packet builder for a forwarded packet.
45    type ForwardedFragmentBuilder: FragmentableIpPacketBuilder<Self>;
46    /// An identifier generated at fragmentation time.
47    type FragmentationId: Copy + Debug;
48}
49
50impl FragmentationIpExt for Ipv4 {
51    type ForwardedFragmentBuilder = ForwardedIpv4PacketBuilder;
52    type FragmentationId = ();
53}
54
55impl FragmentationIpExt for Ipv6 {
56    // IPv6 never fragments forwarded packets, only the source node may
57    // fragment.
58    type ForwardedFragmentBuilder = Uninstantiable;
59    type FragmentationId = u32;
60}
61
62/// Fragmentation errors
63#[derive(Debug, Eq, PartialEq, GenericOverIp)]
64#[generic_over_ip()]
65pub enum FragmentationError {
66    /// Fragmentation not allowed.
67    NotAllowed,
68    /// MTU is too small, headers don't fit.
69    MtuTooSmall,
70    /// Body is too long to be fragmented.
71    BodyTooLong,
72    /// Inner serializer reported a size limited exceeded.
73    SizeLimitExceeded,
74}
75
76/// A [`Serializer`] capable of splitting itself into a packet builder and a
77/// pre-serialized body for fragmentation.
78// TODO(https://fxbug.dev/42148826): Ideally we'd be able to generate fragments
79// without requiring the IP body to be dumped into a Vec first. Update this when
80// that support is available in packet and packet_formats.
81pub trait FragmentableIpSerializer<I: FragmentationIpExt>: NetworkSerializer {
82    /// The builder for each fragment.
83    type Builder<'a>: FragmentableIpPacketBuilder<I>
84    where
85        Self: 'a;
86    /// The body to be fragmented.
87    ///
88    /// Note that this API is not attempting to reuse buffers in any way. There
89    /// are improvements that can be made here to perhaps avoid allocations and
90    /// yield out reusable bodies, but we're constrained to taking references to
91    /// the serializers here to avoid changing the body which could interfere
92    /// with the higher layers on errors.
93    type Body<'a>: AsRef<[u8]>
94    where
95        Self: 'a;
96
97    /// Returns the inner packet builder for this IP version and a serialized
98    /// body.
99    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError>;
100}
101
102impl<I, S, B> FragmentableIpSerializer<I> for Nested<S, B>
103where
104    I: FragmentationIpExt,
105    S: NetworkSerializer,
106    B: AsFragmentableIpPacketBuilder<I> + PacketBuilder<NetworkSerializationContext>,
107{
108    type Builder<'a>
109        = B::Builder<'a>
110    where
111        Self: 'a;
112
113    type Body<'a>
114        = Buf<Vec<u8>>
115    where
116        Self: 'a;
117
118    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
119        let builder = self.outer().try_as_fragmentable()?;
120        let body = self
121            .inner()
122            .serialize_new_buf(
123                &mut NetworkSerializationContext::default(),
124                PacketConstraints::UNCONSTRAINED,
125                packet::new_buf_vec,
126            )
127            .map_err(|e| match e {
128                SerializeError::SizeLimitExceeded => FragmentationError::SizeLimitExceeded,
129            })?;
130        Ok((builder, body))
131    }
132}
133
/// Where a fragment sits in the sequence of fragments produced for a single
/// packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FragmentPosition {
    /// The fragment at offset zero.
    First,
    /// Any fragment that is neither the first nor the last.
    Middle,
    /// The final fragment, built with the more-fragments flag cleared.
    Last,
}
140
/// Header lengths, in bytes, that [`FragmentableIpPacketBuilder`]
/// implementations report for the fragments they build.
pub struct HeaderSizes {
    /// Header length of the first fragment.
    first: usize,
    /// Header length of every fragment after the first. This may be smaller
    /// than `first`, e.g. when some IPv4 options only go in the first fragment.
    remaining: usize,
}
147
148/// A type that may be transformed into a fragmentable ip packet builder.
149pub trait AsFragmentableIpPacketBuilder<I: FragmentationIpExt> {
150    /// The fragmentable packet builder that can be constructed from this type.
151    type Builder<'a>: FragmentableIpPacketBuilder<I>
152    where
153        Self: 'a;
154
155    /// Attempts to extract a `FragmentableIpPacketBuilder` implementation from
156    /// this type, returning an error if it can't be fragmented.
157    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError>;
158}
159
160/// An IP packet builder that can create IP fragments.
161pub trait FragmentableIpPacketBuilder<I: FragmentationIpExt> {
162    /// Returns the portion of the MTU occupied by IP headers.
163    fn header_sizes(&self) -> HeaderSizes;
164
165    /// Returns a builder for fragment at offset `offset`.
166    ///
167    /// `position` carries information if this is the first or last segment, which
168    /// require special logic.
169    fn builder_at(
170        &self,
171        offset: FragmentOffset,
172        position: FragmentPosition,
173        identifier: I::FragmentationId,
174    ) -> impl PacketBuilder<NetworkSerializationContext> + '_;
175}
176
177/// Blanket impl for everything that has a shape to fit in `Ipv4FragmentBuilder`
178/// as a provider for fragmentation.
179impl<B> AsFragmentableIpPacketBuilder<Ipv4> for B
180where
181    B: InnerIpv4FragmentBuilder,
182{
183    type Builder<'a>
184        = Ipv4FragmentBuilder<'a, Self>
185    where
186        Self: 'a;
187
188    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
189        can_fragment_ipv4(self.prefix())?;
190        Ok(Ipv4FragmentBuilder { builder: self })
191    }
192}
193
194/// A trait marking all the IPv4 builder types that can be fragmented with
195/// [`Ipv4FragmentBuilder`].
196trait InnerIpv4FragmentBuilder: PacketBuilder<NetworkSerializationContext> {
197    fn prefix(&self) -> &Ipv4PacketBuilder;
198    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder;
199    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder;
200    fn header_sizes(&self) -> HeaderSizes;
201}
202
203impl InnerIpv4FragmentBuilder for Ipv4PacketBuilder {
204    fn prefix(&self) -> &Ipv4PacketBuilder {
205        self
206    }
207
208    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
209        self
210    }
211
212    fn clone_for_fragment(&self, _position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
213        self.clone()
214    }
215
216    fn header_sizes(&self) -> HeaderSizes {
217        let size = self.constraints().header_len();
218        HeaderSizes { first: size, remaining: size }
219    }
220}
221
222impl<'a, I> InnerIpv4FragmentBuilder for Ipv4PacketBuilderWithOptions<'a, I>
223where
224    I: Iterator<Item: Borrow<Ipv4Option<'a>>> + Clone,
225{
226    fn prefix(&self) -> &Ipv4PacketBuilder {
227        self.prefix_builder()
228    }
229
230    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
231        self.prefix_builder_mut()
232    }
233
234    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
235        self.clone().with_fragment_options(position == FragmentPosition::First)
236    }
237
238    fn header_sizes(&self) -> HeaderSizes {
239        let first = self.constraints().header_len();
240        let remaining = self.clone().with_fragment_options(false).constraints().header_len();
241        HeaderSizes { first, remaining }
242    }
243}
244
/// A [`FragmentableIpPacketBuilder`] for IPv4 that borrows the original
/// builder and clones it for every fragment it produces.
pub struct Ipv4FragmentBuilder<'a, B> {
    /// The builder describing the unfragmented packet.
    builder: &'a B,
}
248
249impl<'a, B> FragmentableIpPacketBuilder<Ipv4> for Ipv4FragmentBuilder<'a, B>
250where
251    B: InnerIpv4FragmentBuilder,
252{
253    fn header_sizes(&self) -> HeaderSizes {
254        self.builder.header_sizes()
255    }
256
257    fn builder_at(
258        &self,
259        offset: FragmentOffset,
260        position: FragmentPosition,
261        (): (),
262    ) -> impl PacketBuilder<NetworkSerializationContext> + '_ {
263        let mut builder = self.builder.clone_for_fragment(position);
264        set_ipv4_fragment(builder.prefix_mut(), offset, position);
265        builder
266    }
267}
268
269impl<B> AsFragmentableIpPacketBuilder<Ipv6> for B
270where
271    for<'a> &'a B: Ipv6PacketBuilderBeforeFragment,
272{
273    type Builder<'a>
274        = Ipv6FragmentBuilder<'a, Self>
275    where
276        Self: 'a;
277
278    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
279        Ok(Ipv6FragmentBuilder { builder: self })
280    }
281}
282
/// A [`FragmentableIpPacketBuilder`] for IPv6 that inserts a fragment header
/// after the headers of the borrowed builder.
pub struct Ipv6FragmentBuilder<'a, B> {
    /// The builder describing the unfragmented packet.
    builder: &'a B,
}
286
287impl<'a, B> FragmentableIpPacketBuilder<Ipv6> for Ipv6FragmentBuilder<'a, B>
288where
289    &'a B: Ipv6PacketBuilderBeforeFragment,
290{
291    fn header_sizes(&self) -> HeaderSizes {
292        // NB: We currently only support headers that need to be in all
293        // fragments, so we only need to calculate once. We might need to change
294        // the trait shape if that changes.
295        let header_len =
296            Ipv6PacketBuilderWithFragmentHeader::new(self.builder, FragmentOffset::ZERO, false, 0)
297                .constraints()
298                .header_len();
299        HeaderSizes { first: header_len, remaining: header_len }
300    }
301
302    fn builder_at(
303        &self,
304        offset: FragmentOffset,
305        position: FragmentPosition,
306        identifier: u32,
307    ) -> impl PacketBuilder<NetworkSerializationContext> + '_ {
308        Ipv6PacketBuilderWithFragmentHeader::new(
309            self.builder,
310            offset,
311            position != FragmentPosition::Last,
312            identifier,
313        )
314    }
315}
316
317impl<I, B> FragmentableIpSerializer<I> for ForwardedPacket<I, B>
318where
319    I: FragmentationIpExt,
320    B: BufferMut + NetworkSerializer,
321{
322    type Builder<'a>
323        = I::ForwardedFragmentBuilder
324    where
325        Self: 'a;
326    type Body<'a>
327        = Buf<&'a [u8]>
328    where
329        Self: 'a;
330
331    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
332        #[derive(GenericOverIp)]
333        #[generic_over_ip(I, Ip)]
334        struct Out<I: FragmentationIpExt>(I::ForwardedFragmentBuilder);
335        I::map_ip::<_, Result<(Out<I>, IpInvariant<Buf<&[u8]>>), FragmentationError>>(
336            self,
337            |forwarded| {
338                // Parse an IPv4 packet from the forwarded packet. We can assert
339                // strongly on all of the parsing here because ForwardedPacket
340                // is guaranteed to have been parsed by the IP stack already.
341                let mut buffer = forwarded.buffer().as_ref();
342                let packet = Ipv4PacketRaw::parse(&mut buffer, ())
343                    .expect("ForwardedPacket must be parseable");
344                let builder = packet.builder();
345                can_fragment_ipv4(&builder)?;
346                let raw_options_bytes = packet
347                    .options()
348                    .as_ref()
349                    .complete()
350                    .expect("unexpected incomplete IP header")
351                    .bytes();
352
353                let mut raw_options = Buf::new(
354                    [0u8; packet_formats::ipv4::MAX_OPTIONS_LEN],
355                    ..raw_options_bytes.len(),
356                );
357                raw_options.as_mut().copy_from_slice(raw_options_bytes);
358                let body = Buf::new(
359                    packet.into_body().complete().expect("unexpected incomplete IP body"),
360                    ..,
361                );
362                Ok((Out(ForwardedIpv4PacketBuilder { builder, raw_options }), IpInvariant(body)))
363            },
364            |_forwarded| Err(FragmentationError::NotAllowed),
365        )
366        .map(|(Out(builder), IpInvariant(body))| (builder, body))
367    }
368}
369
370pub struct ForwardedIpv4PacketBuilder {
371    builder: Ipv4PacketBuilder,
372    raw_options: Buf<[u8; packet_formats::ipv4::MAX_OPTIONS_LEN]>,
373}
374
375impl FragmentableIpPacketBuilder<Ipv4> for ForwardedIpv4PacketBuilder {
376    fn header_sizes(&self) -> HeaderSizes {
377        let Self { builder, raw_options } = self;
378        if raw_options.is_empty() {
379            builder.header_sizes()
380        } else {
381            let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
382                .expect("must hold valid options");
383            Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.iter())
384                .header_sizes()
385        }
386    }
387
388    fn builder_at(
389        &self,
390        offset: FragmentOffset,
391        position: FragmentPosition,
392        (): (),
393    ) -> impl PacketBuilder<NetworkSerializationContext> + '_ {
394        let Self { builder, raw_options } = self;
395        let mut builder = builder.clone();
396        set_ipv4_fragment(&mut builder, offset, position);
397        let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
398            .expect("must hold valid options");
399        Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.into_iter())
400            .with_fragment_options(position == FragmentPosition::First)
401    }
402}
403
404impl<I: FragmentationIpExt> FragmentableIpPacketBuilder<I> for Uninstantiable {
405    fn header_sizes(&self) -> HeaderSizes {
406        self.uninstantiable_unreachable()
407    }
408
409    fn builder_at(
410        &self,
411        _offset: FragmentOffset,
412        _position: FragmentPosition,
413        _identifier: I::FragmentationId,
414    ) -> impl PacketBuilder<NetworkSerializationContext> + '_ {
415        self.uninstantiable_unreachable::<Ipv6PacketBuilder>()
416    }
417}
418
419/// Abstracts fragment ID generation for [`IpFragmenter`].
420///
421/// A blanket impl is provided for [`RngContext`] implementers, so the bindings
422/// context can be used to generate random IDs for IPv6.
423pub(crate) trait FragmentationIdGenContext {
424    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId;
425}
426
427#[derive(GenericOverIp)]
428#[generic_over_ip(I, Ip)]
429struct WrapFragmentationId<I: FragmentationIpExt>(I::FragmentationId);
430
431impl<BC> FragmentationIdGenContext for BC
432where
433    BC: RngContext,
434{
435    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
436        let WrapFragmentationId(identifier) = I::map_ip_out(
437            self,
438            |_| WrapFragmentationId(()),
439            |rng| {
440                // TODO(https://fxbug.dev/373428005): Perhaps we can do better
441                // than a simple RNG. This is currently copying what netstack2
442                // does. RFC 7739 calls out different strategies for fragment
443                // IDs in IPv6. We currently pick an option that is not doing a
444                // best effort to avoid collisions, but it guarantees that
445                // fragment IDs can't be tracked as an attack vector.
446                // We avoid a zero fragment ID like netstack2 does.
447                WrapFragmentationId(rng.rng().random_range(1..=u32::MAX))
448            },
449        );
450        identifier
451    }
452}
453
454pub(crate) struct IpFragmenter<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I> + 'a> {
455    builder: S::Builder<'a>,
456    body: S::Body<'a>,
457    consumed: usize,
458    max_fragment_body_first: usize,
459    max_fragment_body_remaining: usize,
460    identifier: I::FragmentationId,
461}
462
463/// Returns the biggest fragment body that can fit in `mtu` with a given IP
464/// `header` size.
465///
466/// The returned body size is rounded down to the nearest multiple of 8 to fit
467/// the IP header representation of fragment offsets.
468fn maximum_fragment_body_with_header_and_mtu(
469    mtu: Mtu,
470    header: usize,
471) -> Result<usize, FragmentationError> {
472    let v = usize::from(mtu).checked_sub(header).ok_or(FragmentationError::MtuTooSmall)?;
473    // Mask the final 8 bits since fragment offset is expressed in units
474    // of 8 octets for both IP versions.
475    let v = v & !0x07usize;
476
477    if v == 0 {
478        // Can't fragment if we don't have at least a single 8 octet
479        // of space.
480        return Err(FragmentationError::MtuTooSmall);
481    }
482    Ok(v)
483}
484
485impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
486    /// Creates a new `IpFragmenter` with some `serializer` respecting a maximum
487    /// IP layer `mtu`.
488    pub(crate) fn new<C: FragmentationIdGenContext>(
489        id_ctx: &mut C,
490        serializer: &'a S,
491        mtu: Mtu,
492    ) -> Result<Self, FragmentationError> {
493        let (builder, body) = serializer.builder_and_body()?;
494        let HeaderSizes { first, remaining } = builder.header_sizes();
495        let max_fragment_body_first = maximum_fragment_body_with_header_and_mtu(mtu, first)?;
496        let max_fragment_body_remaining =
497            maximum_fragment_body_with_header_and_mtu(mtu, remaining)?;
498
499        if body.as_ref().len() > MAX_FRAGMENT_OFFSET + max_fragment_body_remaining {
500            return Err(FragmentationError::BodyTooLong);
501        }
502
503        let identifier = id_ctx.generate_id::<I>();
504
505        Ok(Self {
506            builder,
507            body,
508            consumed: 0,
509            max_fragment_body_first,
510            max_fragment_body_remaining,
511            identifier,
512        })
513    }
514
515    /// Returns the serializer for the next segment and a boolean indicating
516    /// whether more fragments are pending, or `None` if all segments have been
517    /// produced.
518    ///
519    /// # Panics
520    ///
521    /// Panics if fragmentation is not necessary for the `serializer` that
522    /// created this `IpFragmenter`.
523    pub(crate) fn next(&mut self) -> Option<(impl NetworkSerializer<Buffer = EmptyBuf>, bool)> {
524        let Self {
525            builder,
526            body,
527            consumed,
528            max_fragment_body_first,
529            max_fragment_body_remaining,
530            identifier,
531        } = self;
532        let body = &AsRef::as_ref(body)[*consumed..];
533        if body.is_empty() {
534            return None;
535        }
536        let first = *consumed == 0;
537        let max_fragment_body =
538            if first { max_fragment_body_first } else { max_fragment_body_remaining };
539        let take = body.len().min(*max_fragment_body);
540        let last = take == body.len();
541        let position = match (first, last) {
542            (true, true) => {
543                panic!("unnecessary fragmentation");
544            }
545            (true, false) => FragmentPosition::First,
546            (false, false) => FragmentPosition::Middle,
547            (false, true) => FragmentPosition::Last,
548        };
549        // Upon construction IpFragmenter verifies that we won't go over the
550        // maximum offset since the body length is known.
551        let fragment_offset = u16::try_from(*consumed).expect("fragment offset too large");
552        // Care is taken above to always take 8-byte multiples to be added to
553        // consumed, so we should always have a good representation for
554        // FragmentOffset.
555        let fragment_offset =
556            FragmentOffset::new_with_bytes(fragment_offset).expect("invalid offset");
557        let fragment_builder = builder.builder_at(fragment_offset, position, *identifier);
558        let end = *consumed + take;
559        let has_more = body.len() > take;
560        let fragment_body = &body[..take];
561        *consumed = end;
562        Some((fragment_builder.wrap_body(fragment_body.into_serializer()), has_more))
563    }
564}
565
566fn can_fragment_ipv4(builder: &Ipv4PacketBuilder) -> Result<(), FragmentationError> {
567    if builder.read_df_flag() {
568        return Err(FragmentationError::NotAllowed);
569    }
570    Ok(())
571}
572
573fn set_ipv4_fragment(
574    builder: &mut Ipv4PacketBuilder,
575    offset: FragmentOffset,
576    position: FragmentPosition,
577) {
578    builder.mf_flag(position != FragmentPosition::Last);
579    builder.fragment_offset(offset);
580}
581
582/// Counters kept by the IP stack pertaining to fragmentation.
583#[derive(Default, Debug)]
584#[cfg_attr(
585    any(test, feature = "testutils"),
586    derive(PartialEq, netstack3_macros::CounterCollection)
587)]
588pub struct FragmentationCounters<C = Counter> {
589    /// The number of IP frames requiring fragmentation on egress.
590    pub fragmentation_required: C,
591    /// The total number of fragments sent.
592    pub fragments: C,
593    /// The number of `NotAllowed` errors encountered.
594    pub error_not_allowed: C,
595    /// The number of `MtuTooSmall` errors encountered.
596    pub error_mtu_too_small: C,
597    /// The number of `BodyTooLong` errors encountered.
598    pub error_body_too_long: C,
599    /// The number of `SizeLimitExceeded` errors encountered.
600    pub error_inner_size_limit_exceeded: C,
601    /// Counts the number of times fragmentation was short-circuited due to a
602    /// fragment serialization error.
603    pub error_fragmented_serializer: C,
604}
605
606impl FragmentationCounters {
607    pub(crate) fn error_counter(&self, error: &FragmentationError) -> &Counter {
608        match error {
609            FragmentationError::NotAllowed => &self.error_not_allowed,
610            FragmentationError::MtuTooSmall => &self.error_mtu_too_small,
611            FragmentationError::BodyTooLong => &self.error_body_too_long,
612            FragmentationError::SizeLimitExceeded => &self.error_inner_size_limit_exceeded,
613        }
614    }
615}
616
617#[cfg(test)]
618mod tests {
619    use super::*;
620
621    use assert_matches::assert_matches;
622    use net_types::Witness as _;
623    use netstack3_base::testutil::{TEST_ADDRS_V4, TEST_ADDRS_V6};
624    use netstack3_filter::FilterIpExt;
625    use packet::{Buffer, BufferView, GrowBuffer, Serializer};
626    use packet_formats::ip::IpProto;
627    use packet_formats::ipv4::Ipv4Packet;
628    use packet_formats::ipv6::ext_hdrs::Ipv6ExtensionHeaderData;
629    use packet_formats::ipv6::{Ipv6Header, Ipv6Packet};
630    use test_case::test_case;
631
632    const TEST_MTU: Mtu = Ipv6::MINIMUM_LINK_MTU;
633
634    fn gen_body(len: usize) -> Vec<u8> {
635        // Cycle bytes until 251 which is the largest prime that can fit in a
636        // u8. Unlikely this aligns poorly and hides fragmentation bugs.
637        (0u8..=251).cycle().take(len).collect::<Vec<u8>>()
638    }
639
640    impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
641        fn next_serialized(&mut self) -> Buf<Vec<u8>> {
642            self.next()
643                .expect("no more fragments")
644                .0
645                .serialize_vec_outer(&mut NetworkSerializationContext::default())
646                .map_err(|(err, _serializer)| err)
647                .unwrap()
648                .unwrap_b()
649        }
650    }
651
652    trait FragmentationTestEnv<I: FragmentationIpExt> {
653        fn new_serializer<'a>(
654            &self,
655            body: &'a [u8],
656        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a;
657        fn check_fragment(
658            &self,
659            fragment: &mut Buf<Vec<u8>>,
660            position: FragmentPosition,
661            offset: usize,
662        );
663    }
664
665    #[derive(Default)]
666    struct Ipv4TestEnv {
667        dont_frag: bool,
668    }
669
670    impl Ipv4TestEnv {
671        const fn dont_frag() -> Self {
672            Self { dont_frag: true }
673        }
674    }
675
676    const IPV4_ID: u16 = 0x1234;
677    fn new_ipv4_packet_builder(dont_frag: bool) -> Ipv4PacketBuilder {
678        let mut builder = Ipv4PacketBuilder::new(
679            TEST_ADDRS_V4.local_ip,
680            TEST_ADDRS_V4.remote_ip,
681            1,
682            IpProto::Udp.into(),
683        );
684        builder.id(IPV4_ID);
685        builder.df_flag(dont_frag);
686        builder
687    }
688
689    fn parse_and_check_ipv4_packet(
690        fragment: &mut Buf<Vec<u8>>,
691        position: FragmentPosition,
692        offset: usize,
693    ) -> Ipv4Packet<&[u8]> {
694        let packet = Ipv4Packet::parse(fragment.buffer_view(), ()).expect("parse fragment");
695        assert_eq!(packet.src_ip(), TEST_ADDRS_V4.local_ip.get());
696        assert_eq!(packet.dst_ip(), TEST_ADDRS_V4.remote_ip.get());
697        assert_eq!(packet.ttl(), 1);
698        assert_eq!(packet.id(), IPV4_ID);
699        assert_eq!(packet.proto(), IpProto::Udp.into());
700        assert_eq!(packet.mf_flag(), position != FragmentPosition::Last);
701        assert_eq!(usize::from(packet.fragment_offset().into_bytes()), offset);
702        packet
703    }
704
705    impl FragmentationTestEnv<Ipv4> for Ipv4TestEnv {
706        fn new_serializer<'a>(
707            &self,
708            body: &'a [u8],
709        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
710            let Self { dont_frag } = self;
711            new_ipv4_packet_builder(*dont_frag).wrap_body(body.into_serializer())
712        }
713
714        fn check_fragment(
715            &self,
716            fragment: &mut Buf<Vec<u8>>,
717            position: FragmentPosition,
718            offset: usize,
719        ) {
720            let _ = parse_and_check_ipv4_packet(fragment, position, offset);
721        }
722    }
723
724    #[derive(Default)]
725    struct Ipv4WithOptionsTestEnv(Ipv4TestEnv);
726
727    // The MSB of an option kind determines if it should be copied.
728    const FAKE_OPTION_COPIED_KIND: u8 = 255;
729    const FAKE_OPTION_COPIED: [u8; 1] = [255];
730    const FAKE_OPTION_NOT_COPIED_KIND: u8 = 127;
731    const FAKE_OPTION_NOT_COPIED: [u8; 1] = [127];
732
733    impl FragmentationTestEnv<Ipv4> for Ipv4WithOptionsTestEnv {
734        fn new_serializer<'a>(
735            &self,
736            body: &'a [u8],
737        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
738            let Self(Ipv4TestEnv { dont_frag }) = self;
739
740            Ipv4PacketBuilderWithOptions::new(
741                new_ipv4_packet_builder(*dont_frag),
742                [
743                    Ipv4Option::Unrecognized {
744                        kind: FAKE_OPTION_COPIED_KIND,
745                        data: &FAKE_OPTION_COPIED[..],
746                    },
747                    Ipv4Option::Unrecognized {
748                        kind: FAKE_OPTION_NOT_COPIED_KIND,
749                        data: &FAKE_OPTION_NOT_COPIED[..],
750                    },
751                ],
752            )
753            .unwrap()
754            .wrap_body(body.into_serializer())
755        }
756
757        fn check_fragment(
758            &self,
759            fragment: &mut Buf<Vec<u8>>,
760            position: FragmentPosition,
761            offset: usize,
762        ) {
763            let packet = parse_and_check_ipv4_packet(fragment, position, offset);
764            let (copied, not_copied) = packet.iter_options().fold(
765                (false, false),
766                |(mut copied, mut not_copied), option| {
767                    let (kind, data) = assert_matches!(option,
768                        Ipv4Option::Unrecognized{ kind, data } => (kind, data)
769                    );
770                    assert_eq!(data.len(), 1);
771                    assert_eq!(data[0], kind);
772                    let seen = match kind {
773                        FAKE_OPTION_COPIED_KIND => &mut copied,
774                        FAKE_OPTION_NOT_COPIED_KIND => &mut not_copied,
775                        k => panic!("unexpected option {k}"),
776                    };
777                    assert_eq!(core::mem::replace(seen, true), false);
778                    (copied, not_copied)
779                },
780            );
781            assert_eq!(copied, true, "must be copied on all fragments {position:?}");
782            assert_eq!(
783                not_copied,
784                position == FragmentPosition::First,
785                "must only be in first fragment {position:?}"
786            );
787        }
788    }
789
790    struct ForwardingTestEnv<E>(E);
791    impl<I: FragmentationIpExt + FilterIpExt, E: FragmentationTestEnv<I>> FragmentationTestEnv<I>
792        for ForwardingTestEnv<E>
793    {
794        fn new_serializer<'a>(
795            &self,
796            body: &'a [u8],
797        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a {
798            use packet_formats::ip::IpPacket as _;
799            let Self(inner) = self;
800            let mut buffer = inner
801                .new_serializer(body)
802                .serialize_outer(
803                    &mut NetworkSerializationContext::default(),
804                    packet::NoReuseBufferProvider(packet::new_buf_vec),
805                )
806                .map_err(|(err, _)| err)
807                .unwrap();
808            let packet =
809                <I::Packet<_> as ParsablePacket<_, _>>::parse(buffer.buffer_view(), ()).unwrap();
810            let src_addr = packet.src_ip();
811            let dst_addr = packet.dst_ip();
812            let proto = packet.proto();
813            let meta = packet.parse_metadata();
814            drop(packet);
815            ForwardedPacket::new(src_addr, dst_addr, proto, meta, buffer)
816        }
817        fn check_fragment(
818            &self,
819            fragment: &mut Buf<Vec<u8>>,
820            position: FragmentPosition,
821            offset: usize,
822        ) {
823            let Self(inner) = self;
824            inner.check_fragment(fragment, position, offset)
825        }
826    }
827
828    struct Ipv6TestEnv;
829
830    const IPV6_ID: u32 = 0x1234ABCD;
831
832    impl FragmentationTestEnv<Ipv6> for Ipv6TestEnv {
833        fn new_serializer<'a>(
834            &self,
835            body: &'a [u8],
836        ) -> impl FragmentableIpSerializer<Ipv6, Buffer: Buffer> + 'a {
837            Ipv6PacketBuilder::new(
838                TEST_ADDRS_V6.local_ip,
839                TEST_ADDRS_V6.remote_ip,
840                1,
841                IpProto::Udp.into(),
842            )
843            .wrap_body(body.into_serializer())
844        }
845
846        fn check_fragment(
847            &self,
848            fragment: &mut Buf<Vec<u8>>,
849            position: FragmentPosition,
850            offset: usize,
851        ) {
852            let packet = Ipv6Packet::parse(fragment.buffer_view(), ()).unwrap();
853            assert_eq!(packet.src_ip(), TEST_ADDRS_V6.local_ip.get());
854            assert_eq!(packet.dst_ip(), TEST_ADDRS_V6.remote_ip.get());
855            assert_eq!(packet.hop_limit(), 1);
856            assert_eq!(packet.proto(), IpProto::Udp.into());
857            let fragment = packet
858                .iter_extension_hdrs()
859                .find_map(|h| match h.into_data() {
860                    Ipv6ExtensionHeaderData::Fragment { fragment_data } => Some(fragment_data),
861                    _ => None,
862                })
863                .expect("no fragment header");
864            assert_eq!(fragment.identification(), IPV6_ID);
865            assert_eq!(usize::from(fragment.fragment_offset().into_bytes()), offset);
866            assert_eq!(fragment.m_flag(), position != FragmentPosition::Last);
867        }
868    }
869
870    struct FixedIdContext;
871    impl FragmentationIdGenContext for FixedIdContext {
872        fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
873            let WrapFragmentationId(id) =
874                I::map_ip_out((), |()| WrapFragmentationId(()), |()| WrapFragmentationId(IPV6_ID));
875            id
876        }
877    }
878
879    #[test_case::test_matrix(
880        [
881            Ipv4TestEnv::default(),
882            Ipv4WithOptionsTestEnv::default(),
883            ForwardingTestEnv(Ipv4TestEnv::default()),
884            ForwardingTestEnv(Ipv4WithOptionsTestEnv::default()),
885            Ipv6TestEnv,
886        ],
887        0..=2
888    )]
889    fn fragment<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(
890        env: E,
891        middle_fragments: usize,
892    ) {
893        // NB: We're using the fact that MTU is larger than the header sizes
894        // here to end up obtaining the right number of middle fragments as
895        // expected. This makes this test sensitive to the relation between the
896        // picked MTU and the header sizes for the multiple serializers.
897        let full_body = gen_body(usize::from(TEST_MTU) * (1 + middle_fragments));
898        let mut body_view = Buf::new(&full_body[..], ..);
899        let serializer = env.new_serializer(&full_body[..]);
900        let mut fragmenter = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU)
901            .expect("create fragmenter");
902
903        let mut frag = fragmenter.next_serialized();
904        env.check_fragment(&mut frag, FragmentPosition::First, body_view.prefix_len());
905        assert_eq!(
906            frag.as_ref(),
907            body_view.buffer_view().take_front(fragmenter.max_fragment_body_first).unwrap()
908        );
909
910        for _ in 0..middle_fragments {
911            let mut frag = fragmenter.next_serialized();
912            env.check_fragment(&mut frag, FragmentPosition::Middle, body_view.prefix_len());
913            assert_eq!(
914                frag.as_ref(),
915                body_view.buffer_view().take_front(fragmenter.max_fragment_body_remaining).unwrap()
916            );
917        }
918
919        let mut frag = fragmenter.next_serialized();
920        env.check_fragment(&mut frag, FragmentPosition::Last, body_view.prefix_len());
921        assert_eq!(frag.as_ref(), body_view.buffer_view().into_rest());
922
923        // No more fragments.
924        assert!(fragmenter.next().is_none());
925    }
926
927    #[test_case(Ipv4TestEnv::dont_frag())]
928    #[test_case(Ipv4WithOptionsTestEnv(Ipv4TestEnv::dont_frag()))]
929    #[test_case(ForwardingTestEnv(Ipv4TestEnv::dont_frag()))]
930    #[test_case(ForwardingTestEnv(Ipv6TestEnv))]
931    fn not_allowed<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
932        let body = gen_body(usize::from(TEST_MTU));
933        let serializer = env.new_serializer(&body[..]);
934        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
935        assert_eq!(result, Err(FragmentationError::NotAllowed))
936    }
937
938    #[test_case(Ipv4TestEnv::default())]
939    #[test_case(Ipv4WithOptionsTestEnv::default())]
940    #[test_case(ForwardingTestEnv(Ipv4TestEnv::default()))]
941    #[test_case(Ipv6TestEnv)]
942    fn mtu_too_small<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
943        let body = gen_body(usize::from(TEST_MTU));
944        let serializer = env.new_serializer(&body[..]);
945        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, Mtu::new(10)).map(|_| ());
946        assert_eq!(result, Err(FragmentationError::MtuTooSmall));
947    }
948
949    #[test_case(Ipv4TestEnv::default())]
950    #[test_case(Ipv4WithOptionsTestEnv::default())]
951    #[test_case(Ipv6TestEnv)]
952    fn body_too_long<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
953        let body = gen_body(MAX_FRAGMENT_OFFSET + usize::from(TEST_MTU));
954        let serializer = env.new_serializer(&body[..]);
955        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
956        assert_eq!(result, Err(FragmentationError::BodyTooLong));
957    }
958}