netstack3_ip/
fragmentation.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! IP fragmentation support.
6
7use core::borrow::Borrow;
8use core::fmt::Debug;
9
10use alloc::vec::Vec;
11
12use explicit::UnreachableExt;
13use net_types::ip::{GenericOverIp, Ip, IpInvariant, Ipv4, Ipv6, Mtu};
14use netstack3_base::{Counter, RngContext, Uninstantiable};
15use netstack3_filter::ForwardedPacket;
16use packet::{
17    Buf, BufferMut, EmptyBuf, FragmentedBuffer as _, InnerPacketBuilder as _, Nested,
18    PacketBuilder, PacketConstraints, ParsablePacket, SerializeError, Serializer,
19};
20use packet_formats::ip::FragmentOffset;
21use packet_formats::ipv4::options::Ipv4Option;
22use packet_formats::ipv4::{
23    Ipv4Header as _, Ipv4PacketBuilder, Ipv4PacketBuilderWithOptions, Ipv4PacketRaw,
24};
25use packet_formats::ipv6::{
26    Ipv6PacketBuilder, Ipv6PacketBuilderBeforeFragment, Ipv6PacketBuilderWithFragmentHeader,
27};
28use rand::Rng;
29
/// Largest fragment offset, in bytes, that both the IPv4 and IPv6 headers can
/// express. The longest body that can be fragmented is this offset plus
/// whatever fits in the final fragment.
// The offset field is 13 bits wide and counts 8-byte units.
const MAX_FRAGMENT_OFFSET: usize = 8 * ((1 << 13) - 1);
35
36pub trait FragmentationIpExt:
37    packet_formats::ip::IpExt<PacketBuilder: AsFragmentableIpPacketBuilder<Self>>
38{
39    /// The IP packet builder for a forwarded packet.
40    type ForwardedFragmentBuilder: FragmentableIpPacketBuilder<Self>;
41    /// An identifier generated at fragmentation time.
42    type FragmentationId: Copy + Debug;
43}
44
45impl FragmentationIpExt for Ipv4 {
46    type ForwardedFragmentBuilder = ForwardedIpv4PacketBuilder;
47    type FragmentationId = ();
48}
49
50impl FragmentationIpExt for Ipv6 {
51    // IPv6 never fragments forwarded packets, only the source node may
52    // fragment.
53    type ForwardedFragmentBuilder = Uninstantiable;
54    type FragmentationId = u32;
55}
56
57/// Fragmentation errors
58#[derive(Debug, Eq, PartialEq, GenericOverIp)]
59#[generic_over_ip()]
60pub enum FragmentationError {
61    /// Fragmentation not allowed.
62    NotAllowed,
63    /// MTU is too small, headers don't fit.
64    MtuTooSmall,
65    /// Body is too long to be fragmented.
66    BodyTooLong,
67    /// Inner serializer reported a size limited exceeded.
68    SizeLimitExceeded,
69}
70
71/// A [`Serializer`] capable of splitting itself into a packet builder and a
72/// pre-serialized body for fragmentation.
73// TODO(https://fxbug.dev/42148826): Ideally we'd be able to generate fragments
74// without requiring the IP body to be dumped into a Vec first. Update this when
75// that support is available in packet and packet_formats.
76pub trait FragmentableIpSerializer<I: FragmentationIpExt>: Serializer {
77    /// The builder for each fragment.
78    type Builder<'a>: FragmentableIpPacketBuilder<I>
79    where
80        Self: 'a;
81    /// The body to be fragmented.
82    ///
83    /// Note that this API is not attempting to reuse buffers in any way. There
84    /// are improvements that can be made here to perhaps avoid allocations and
85    /// yield out reusable bodies, but we're constrained to taking references to
86    /// the serializers here to avoid changing the body which could interfere
87    /// with the higher layers on errors.
88    type Body<'a>: AsRef<[u8]>
89    where
90        Self: 'a;
91
92    /// Returns the inner packet builder for this IP version and a serialized
93    /// body.
94    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError>;
95}
96
97impl<I, S, B> FragmentableIpSerializer<I> for Nested<S, B>
98where
99    I: FragmentationIpExt,
100    S: Serializer,
101    B: AsFragmentableIpPacketBuilder<I> + PacketBuilder,
102{
103    type Builder<'a>
104        = B::Builder<'a>
105    where
106        Self: 'a;
107
108    type Body<'a>
109        = Buf<Vec<u8>>
110    where
111        Self: 'a;
112
113    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
114        let builder = self.outer().try_as_fragmentable()?;
115        let body = self
116            .inner()
117            .serialize_new_buf(PacketConstraints::UNCONSTRAINED, packet::new_buf_vec)
118            .map_err(|e| match e {
119                SerializeError::SizeLimitExceeded => FragmentationError::SizeLimitExceeded,
120            })?;
121        Ok((builder, body))
122    }
123}
124
/// Where a fragment sits among the fragments of a single packet.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FragmentPosition {
    /// The fragment at offset zero.
    First,
    /// Any fragment that is neither the first nor the last.
    Middle,
    /// The final fragment, the only one without the "more fragments" flag.
    Last,
}
131
/// Header lengths reported by [`FragmentableIpPacketBuilder`] implementations,
/// used to work out how much body each fragment can carry.
pub struct HeaderSizes {
    /// Header length, in bytes, of the first fragment.
    first: usize,
    /// Header length, in bytes, of every fragment after the first.
    remaining: usize,
}
138
139/// A type that may be transformed into a fragmentable ip packet builder.
140pub trait AsFragmentableIpPacketBuilder<I: FragmentationIpExt> {
141    /// The fragmentable packet builder that can be constructed from this type.
142    type Builder<'a>: FragmentableIpPacketBuilder<I>
143    where
144        Self: 'a;
145
146    /// Attempts to extract a `FragmentableIpPacketBuilder` implementation from
147    /// this type, returning an error if it can't be fragmented.
148    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError>;
149}
150
151/// An IP packet builder that can create IP fragments.
152pub trait FragmentableIpPacketBuilder<I: FragmentationIpExt> {
153    /// Returns the portion of the MTU occupied by IP headers.
154    fn header_sizes(&self) -> HeaderSizes;
155
156    /// Returns a builder for fragment at offset `offset`.
157    ///
158    /// `position` carries information if this is the first or last segment, which
159    /// require special logic.
160    fn builder_at(
161        &self,
162        offset: FragmentOffset,
163        position: FragmentPosition,
164        identifier: I::FragmentationId,
165    ) -> impl PacketBuilder + '_;
166}
167
168/// Blanket impl for everything that has a shape to fit in `Ipv4FragmentBuilder`
169/// as a provider for fragmentation.
170impl<B> AsFragmentableIpPacketBuilder<Ipv4> for B
171where
172    B: InnerIpv4FragmentBuilder,
173{
174    type Builder<'a>
175        = Ipv4FragmentBuilder<'a, Self>
176    where
177        Self: 'a;
178
179    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
180        can_fragment_ipv4(self.prefix())?;
181        Ok(Ipv4FragmentBuilder { builder: self })
182    }
183}
184
185/// A trait marking all the IPv4 builder types that can be fragmented with
186/// [`Ipv4FragmentBuilder`].
187trait InnerIpv4FragmentBuilder: PacketBuilder {
188    fn prefix(&self) -> &Ipv4PacketBuilder;
189    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder;
190    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder;
191    fn header_sizes(&self) -> HeaderSizes;
192}
193
194impl InnerIpv4FragmentBuilder for Ipv4PacketBuilder {
195    fn prefix(&self) -> &Ipv4PacketBuilder {
196        self
197    }
198
199    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
200        self
201    }
202
203    fn clone_for_fragment(&self, _position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
204        self.clone()
205    }
206
207    fn header_sizes(&self) -> HeaderSizes {
208        let size = self.constraints().header_len();
209        HeaderSizes { first: size, remaining: size }
210    }
211}
212
213impl<'a, I> InnerIpv4FragmentBuilder for Ipv4PacketBuilderWithOptions<'a, I>
214where
215    I: Iterator<Item: Borrow<Ipv4Option<'a>>> + Clone,
216{
217    fn prefix(&self) -> &Ipv4PacketBuilder {
218        self.prefix_builder()
219    }
220
221    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
222        self.prefix_builder_mut()
223    }
224
225    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
226        self.clone().with_fragment_options(position == FragmentPosition::First)
227    }
228
229    fn header_sizes(&self) -> HeaderSizes {
230        let first = self.constraints().header_len();
231        let remaining = self.clone().with_fragment_options(false).constraints().header_len();
232        HeaderSizes { first, remaining }
233    }
234}
235
/// IPv4 [`FragmentableIpPacketBuilder`] that borrows the packet's builder and
/// derives each fragment's header from it.
pub struct Ipv4FragmentBuilder<'a, B> {
    /// Builder of the unfragmented packet.
    builder: &'a B,
}
239
240impl<'a, B> FragmentableIpPacketBuilder<Ipv4> for Ipv4FragmentBuilder<'a, B>
241where
242    B: InnerIpv4FragmentBuilder,
243{
244    fn header_sizes(&self) -> HeaderSizes {
245        self.builder.header_sizes()
246    }
247
248    fn builder_at(
249        &self,
250        offset: FragmentOffset,
251        position: FragmentPosition,
252        (): (),
253    ) -> impl PacketBuilder + '_ {
254        let mut builder = self.builder.clone_for_fragment(position);
255        set_ipv4_fragment(builder.prefix_mut(), offset, position);
256        builder
257    }
258}
259
260impl<B> AsFragmentableIpPacketBuilder<Ipv6> for B
261where
262    for<'a> &'a B: Ipv6PacketBuilderBeforeFragment,
263{
264    type Builder<'a>
265        = Ipv6FragmentBuilder<'a, Self>
266    where
267        Self: 'a;
268
269    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
270        Ok(Ipv6FragmentBuilder { builder: self })
271    }
272}
273
/// IPv6 [`FragmentableIpPacketBuilder`] that places a fragment header after
/// the headers described by the borrowed builder.
pub struct Ipv6FragmentBuilder<'a, B> {
    /// Builder for the headers that precede the fragment header.
    builder: &'a B,
}
277
278impl<'a, B> FragmentableIpPacketBuilder<Ipv6> for Ipv6FragmentBuilder<'a, B>
279where
280    &'a B: Ipv6PacketBuilderBeforeFragment,
281{
282    fn header_sizes(&self) -> HeaderSizes {
283        // NB: We currently only support headers that need to be in all
284        // fragments, so we only need to calculate once. We might need to change
285        // the trait shape if that changes.
286        let header_len =
287            Ipv6PacketBuilderWithFragmentHeader::new(self.builder, FragmentOffset::ZERO, false, 0)
288                .constraints()
289                .header_len();
290        HeaderSizes { first: header_len, remaining: header_len }
291    }
292
293    fn builder_at(
294        &self,
295        offset: FragmentOffset,
296        position: FragmentPosition,
297        identifier: u32,
298    ) -> impl PacketBuilder + '_ {
299        Ipv6PacketBuilderWithFragmentHeader::new(
300            self.builder,
301            offset,
302            position != FragmentPosition::Last,
303            identifier,
304        )
305    }
306}
307
308impl<I, B> FragmentableIpSerializer<I> for ForwardedPacket<I, B>
309where
310    I: FragmentationIpExt,
311    B: BufferMut,
312{
313    type Builder<'a>
314        = I::ForwardedFragmentBuilder
315    where
316        Self: 'a;
317    type Body<'a>
318        = Buf<&'a [u8]>
319    where
320        Self: 'a;
321
322    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
323        #[derive(GenericOverIp)]
324        #[generic_over_ip(I, Ip)]
325        struct Out<I: FragmentationIpExt>(I::ForwardedFragmentBuilder);
326        I::map_ip::<_, Result<(Out<I>, IpInvariant<Buf<&[u8]>>), FragmentationError>>(
327            self,
328            |forwarded| {
329                // Parse an IPv4 packet from the forwarded packet. We can assert
330                // strongly on all of the parsing here because ForwardedPacket
331                // is guaranteed to have been parsed by the IP stack already.
332                let mut buffer = forwarded.buffer().as_ref();
333                let packet = Ipv4PacketRaw::parse(&mut buffer, ())
334                    .expect("ForwardedPacket must be parseable");
335                let builder = packet.builder();
336                can_fragment_ipv4(&builder)?;
337                let raw_options_bytes = packet
338                    .options()
339                    .as_ref()
340                    .complete()
341                    .expect("unexpected incomplete IP header")
342                    .bytes();
343
344                let mut raw_options = Buf::new(
345                    [0u8; packet_formats::ipv4::MAX_OPTIONS_LEN],
346                    ..raw_options_bytes.len(),
347                );
348                raw_options.as_mut().copy_from_slice(raw_options_bytes);
349                let body = Buf::new(
350                    packet.into_body().complete().expect("unexpected incomplete IP body"),
351                    ..,
352                );
353                Ok((Out(ForwardedIpv4PacketBuilder { builder, raw_options }), IpInvariant(body)))
354            },
355            |_forwarded| Err(FragmentationError::NotAllowed),
356        )
357        .map(|(Out(builder), IpInvariant(body))| (builder, body))
358    }
359}
360
361pub struct ForwardedIpv4PacketBuilder {
362    builder: Ipv4PacketBuilder,
363    raw_options: Buf<[u8; packet_formats::ipv4::MAX_OPTIONS_LEN]>,
364}
365
366impl FragmentableIpPacketBuilder<Ipv4> for ForwardedIpv4PacketBuilder {
367    fn header_sizes(&self) -> HeaderSizes {
368        let Self { builder, raw_options } = self;
369        if raw_options.is_empty() {
370            builder.header_sizes()
371        } else {
372            let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
373                .expect("must hold valid options");
374            Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.iter())
375                .header_sizes()
376        }
377    }
378
379    fn builder_at(
380        &self,
381        offset: FragmentOffset,
382        position: FragmentPosition,
383        (): (),
384    ) -> impl PacketBuilder + '_ {
385        let Self { builder, raw_options } = self;
386        let mut builder = builder.clone();
387        set_ipv4_fragment(&mut builder, offset, position);
388        let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
389            .expect("must hold valid options");
390        Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.into_iter())
391            .with_fragment_options(position == FragmentPosition::First)
392    }
393}
394
395impl<I: FragmentationIpExt> FragmentableIpPacketBuilder<I> for Uninstantiable {
396    fn header_sizes(&self) -> HeaderSizes {
397        self.uninstantiable_unreachable()
398    }
399
400    fn builder_at(
401        &self,
402        _offset: FragmentOffset,
403        _position: FragmentPosition,
404        _identifier: I::FragmentationId,
405    ) -> impl PacketBuilder + '_ {
406        self.uninstantiable_unreachable::<Ipv6PacketBuilder>()
407    }
408}
409
410/// Abstracts fragment ID generation for [`IpFragmenter`].
411///
412/// A blanket impl is provided for [`RngContext`] implementers, so the bindings
413/// context can be used to generate random IDs for IPv6.
414pub(crate) trait FragmentationIdGenContext {
415    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId;
416}
417
418#[derive(GenericOverIp)]
419#[generic_over_ip(I, Ip)]
420struct WrapFragmentationId<I: FragmentationIpExt>(I::FragmentationId);
421
422impl<BC> FragmentationIdGenContext for BC
423where
424    BC: RngContext,
425{
426    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
427        let WrapFragmentationId(identifier) = I::map_ip_out(
428            self,
429            |_| WrapFragmentationId(()),
430            |rng| {
431                // TODO(https://fxbug.dev/373428005): Perhaps we can do better
432                // than a simple RNG. This is currently copying what netstack2
433                // does. RFC 7739 calls out different strategies for fragment
434                // IDs in IPv6. We currently pick an option that is not doing a
435                // best effort to avoid collisions, but it guarantees that
436                // fragment IDs can't be tracked as an attack vector.
437                // We avoid a zero fragment ID like netstack2 does.
438                WrapFragmentationId(rng.rng().gen_range(1..=u32::MAX))
439            },
440        );
441        identifier
442    }
443}
444
445pub(crate) struct IpFragmenter<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I> + 'a> {
446    builder: S::Builder<'a>,
447    body: S::Body<'a>,
448    consumed: usize,
449    max_fragment_body_first: usize,
450    max_fragment_body_remaining: usize,
451    identifier: I::FragmentationId,
452}
453
/// Helper trait that lets [`IpFragmenter::next`] name both the fragmenter's
/// data lifetime `'a` and the `&mut self` borrow `'b` in its opaque return type.
// TODO(https://github.com/rust-lang/rust/issues/123432): Replace with `impl
// use<'a, 'b>` when available in tree.
pub trait Capture<'a, 'b> {}

impl<'a, 'b, T> Capture<'a, 'b> for T
where
    'a: 'b,
    T: 'b,
{
}
464
465/// Returns the biggest fragment body that can fit in `mtu` with a given IP
466/// `header` size.
467///
468/// The returned body size is rounded down to the nearest multiple of 8 to fit
469/// the IP header representation of fragment offsets.
470fn maximum_fragment_body_with_header_and_mtu(
471    mtu: Mtu,
472    header: usize,
473) -> Result<usize, FragmentationError> {
474    let v = usize::from(mtu).checked_sub(header).ok_or(FragmentationError::MtuTooSmall)?;
475    // Mask the final 8 bits since fragment offset is expressed in units
476    // of 8 octets for both IP versions.
477    let v = v & !0x07usize;
478
479    if v == 0 {
480        // Can't fragment if we don't have at least a single 8 octet
481        // of space.
482        return Err(FragmentationError::MtuTooSmall);
483    }
484    Ok(v)
485}
486
487impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
488    /// Creates a new `IpFragmenter` with some `serializer` respecting a maximum
489    /// IP layer `mtu`.
490    pub(crate) fn new<C: FragmentationIdGenContext>(
491        id_ctx: &mut C,
492        serializer: &'a S,
493        mtu: Mtu,
494    ) -> Result<Self, FragmentationError> {
495        let (builder, body) = serializer.builder_and_body()?;
496        let HeaderSizes { first, remaining } = builder.header_sizes();
497        let max_fragment_body_first = maximum_fragment_body_with_header_and_mtu(mtu, first)?;
498        let max_fragment_body_remaining =
499            maximum_fragment_body_with_header_and_mtu(mtu, remaining)?;
500
501        if body.as_ref().len() > MAX_FRAGMENT_OFFSET + max_fragment_body_remaining {
502            return Err(FragmentationError::BodyTooLong);
503        }
504
505        let identifier = id_ctx.generate_id::<I>();
506
507        Ok(Self {
508            builder,
509            body,
510            consumed: 0,
511            max_fragment_body_first,
512            max_fragment_body_remaining,
513            identifier,
514        })
515    }
516
517    /// Returns the serializer for the next segment and a boolean indicating
518    /// whether more fragments are pending, or `None` if all segments have been
519    /// produced.
520    ///
521    /// # Panics
522    ///
523    /// Panics if fragmentation is not necessary for the `serializer` that
524    /// created this `IpFragmenter`.
525    pub(crate) fn next(
526        &mut self,
527    ) -> Option<(impl Serializer<Buffer = EmptyBuf> + Capture<'a, '_>, bool)> {
528        let Self {
529            builder,
530            body,
531            consumed,
532            max_fragment_body_first,
533            max_fragment_body_remaining,
534            identifier,
535        } = self;
536        let body = &AsRef::as_ref(body)[*consumed..];
537        if body.is_empty() {
538            return None;
539        }
540        let first = *consumed == 0;
541        let max_fragment_body =
542            if first { max_fragment_body_first } else { max_fragment_body_remaining };
543        let take = body.len().min(*max_fragment_body);
544        let last = take == body.len();
545        let position = match (first, last) {
546            (true, true) => {
547                panic!("unnecessary fragmentation");
548            }
549            (true, false) => FragmentPosition::First,
550            (false, false) => FragmentPosition::Middle,
551            (false, true) => FragmentPosition::Last,
552        };
553        // Upon construction IpFragmenter verifies that we won't go over the
554        // maximum offset since the body length is known.
555        let fragment_offset = u16::try_from(*consumed).expect("fragment offset too large");
556        // Care is taken above to always take 8-byte multiples to be added to
557        // consumed, so we should always have a good representation for
558        // FragmentOffset.
559        let fragment_offset =
560            FragmentOffset::new_with_bytes(fragment_offset).expect("invalid offset");
561        let fragment_builder = builder.builder_at(fragment_offset, position, *identifier);
562        let end = *consumed + take;
563        let has_more = body.len() > take;
564        let fragment_body = &body[..take];
565        *consumed = end;
566        Some((fragment_body.into_serializer().encapsulate(fragment_builder), has_more))
567    }
568}
569
570fn can_fragment_ipv4(builder: &Ipv4PacketBuilder) -> Result<(), FragmentationError> {
571    if builder.read_df_flag() {
572        return Err(FragmentationError::NotAllowed);
573    }
574    Ok(())
575}
576
577fn set_ipv4_fragment(
578    builder: &mut Ipv4PacketBuilder,
579    offset: FragmentOffset,
580    position: FragmentPosition,
581) {
582    builder.mf_flag(position != FragmentPosition::Last);
583    builder.fragment_offset(offset);
584}
585
586/// Counters kept by the IP stack pertaining to fragmentation.
587#[derive(Default)]
588pub struct FragmentationCounters {
589    /// The number of IP frames requiring fragmentation on egress.
590    pub fragmentation_required: Counter,
591    /// The total number of fragments sent.
592    pub fragments: Counter,
593    /// The number of `NotAllowed` errors encountered.
594    pub error_not_allowed: Counter,
595    /// The number of `MtuTooSmall` errors encountered.
596    pub error_mtu_too_small: Counter,
597    /// The number of `BodyTooLong` errors encountered.
598    pub error_body_too_long: Counter,
599    /// The number of `SizeLimitExceeded` errors encountered.
600    pub error_inner_size_limit_exceeded: Counter,
601    /// Counts the number of times fragmentation was short-circuited due to a
602    /// fragment serialization error.
603    pub error_fragmented_serializer: Counter,
604}
605
606impl FragmentationCounters {
607    pub(crate) fn error_counter(&self, error: FragmentationError) -> &Counter {
608        match error {
609            FragmentationError::NotAllowed => &self.error_not_allowed,
610            FragmentationError::MtuTooSmall => &self.error_mtu_too_small,
611            FragmentationError::BodyTooLong => &self.error_body_too_long,
612            FragmentationError::SizeLimitExceeded => &self.error_inner_size_limit_exceeded,
613        }
614    }
615}
616
617#[cfg(test)]
618mod tests {
619    use super::*;
620
621    use assert_matches::assert_matches;
622    use net_types::Witness as _;
623    use netstack3_base::testutil::{TEST_ADDRS_V4, TEST_ADDRS_V6};
624    use packet::{Buffer, BufferView, GrowBuffer};
625    use packet_formats::ip::IpProto;
626    use packet_formats::ipv4::Ipv4Packet;
627    use packet_formats::ipv6::ext_hdrs::Ipv6ExtensionHeaderData;
628    use packet_formats::ipv6::{Ipv6Header, Ipv6Packet};
629    use test_case::test_case;
630
631    const TEST_MTU: Mtu = Ipv6::MINIMUM_LINK_MTU;
632
633    fn gen_body(len: usize) -> Vec<u8> {
634        // Cycle bytes until 251 which is the largest prime that can fit in a
635        // u8. Unlikely this aligns poorly and hides fragmentation bugs.
636        (0u8..=251).cycle().take(len).collect::<Vec<u8>>()
637    }
638
639    impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
640        fn next_serialized(&mut self) -> Buf<Vec<u8>> {
641            self.next()
642                .expect("no more fragments")
643                .0
644                .serialize_vec_outer()
645                .map_err(|(err, _serializer)| err)
646                .unwrap()
647                .unwrap_b()
648        }
649    }
650
651    trait FragmentationTestEnv<I: FragmentationIpExt> {
652        fn new_serializer<'a>(
653            &self,
654            body: &'a [u8],
655        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a;
656        fn check_fragment(
657            &self,
658            fragment: &mut Buf<Vec<u8>>,
659            position: FragmentPosition,
660            offset: usize,
661        );
662    }
663
664    #[derive(Default)]
665    struct Ipv4TestEnv {
666        dont_frag: bool,
667    }
668
669    impl Ipv4TestEnv {
670        const fn dont_frag() -> Self {
671            Self { dont_frag: true }
672        }
673    }
674
675    const IPV4_ID: u16 = 0x1234;
676    fn new_ipv4_packet_builder(dont_frag: bool) -> Ipv4PacketBuilder {
677        let mut builder = Ipv4PacketBuilder::new(
678            TEST_ADDRS_V4.local_ip,
679            TEST_ADDRS_V4.remote_ip,
680            1,
681            IpProto::Udp.into(),
682        );
683        builder.id(IPV4_ID);
684        builder.df_flag(dont_frag);
685        builder
686    }
687
688    fn parse_and_check_ipv4_packet(
689        fragment: &mut Buf<Vec<u8>>,
690        position: FragmentPosition,
691        offset: usize,
692    ) -> Ipv4Packet<&[u8]> {
693        let packet = Ipv4Packet::parse(fragment.buffer_view(), ()).expect("parse fragment");
694        assert_eq!(packet.src_ip(), TEST_ADDRS_V4.local_ip.get());
695        assert_eq!(packet.dst_ip(), TEST_ADDRS_V4.remote_ip.get());
696        assert_eq!(packet.ttl(), 1);
697        assert_eq!(packet.id(), IPV4_ID);
698        assert_eq!(packet.proto(), IpProto::Udp.into());
699        assert_eq!(packet.mf_flag(), position != FragmentPosition::Last);
700        assert_eq!(usize::from(packet.fragment_offset().into_bytes()), offset);
701        packet
702    }
703
704    impl FragmentationTestEnv<Ipv4> for Ipv4TestEnv {
705        fn new_serializer<'a>(
706            &self,
707            body: &'a [u8],
708        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
709            let Self { dont_frag } = self;
710            body.into_serializer().encapsulate(new_ipv4_packet_builder(*dont_frag))
711        }
712
713        fn check_fragment(
714            &self,
715            fragment: &mut Buf<Vec<u8>>,
716            position: FragmentPosition,
717            offset: usize,
718        ) {
719            let _ = parse_and_check_ipv4_packet(fragment, position, offset);
720        }
721    }
722
723    #[derive(Default)]
724    struct Ipv4WithOptionsTestEnv(Ipv4TestEnv);
725
726    // The MSB of an option kind determines if it should be copied.
727    const FAKE_OPTION_COPIED_KIND: u8 = 255;
728    const FAKE_OPTION_COPIED: [u8; 1] = [255];
729    const FAKE_OPTION_NOT_COPIED_KIND: u8 = 127;
730    const FAKE_OPTION_NOT_COPIED: [u8; 1] = [127];
731
732    impl FragmentationTestEnv<Ipv4> for Ipv4WithOptionsTestEnv {
733        fn new_serializer<'a>(
734            &self,
735            body: &'a [u8],
736        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
737            let Self(Ipv4TestEnv { dont_frag }) = self;
738            body.into_serializer().encapsulate(
739                Ipv4PacketBuilderWithOptions::new(
740                    new_ipv4_packet_builder(*dont_frag),
741                    [
742                        Ipv4Option::Unrecognized {
743                            kind: FAKE_OPTION_COPIED_KIND,
744                            data: &FAKE_OPTION_COPIED[..],
745                        },
746                        Ipv4Option::Unrecognized {
747                            kind: FAKE_OPTION_NOT_COPIED_KIND,
748                            data: &FAKE_OPTION_NOT_COPIED[..],
749                        },
750                    ],
751                )
752                .unwrap(),
753            )
754        }
755
756        fn check_fragment(
757            &self,
758            fragment: &mut Buf<Vec<u8>>,
759            position: FragmentPosition,
760            offset: usize,
761        ) {
762            let packet = parse_and_check_ipv4_packet(fragment, position, offset);
763            let (copied, not_copied) = packet.iter_options().fold(
764                (false, false),
765                |(mut copied, mut not_copied), option| {
766                    let (kind, data) = assert_matches!(option,
767                        Ipv4Option::Unrecognized{ kind, data } => (kind, data)
768                    );
769                    assert_eq!(data.len(), 1);
770                    assert_eq!(data[0], kind);
771                    let seen = match kind {
772                        FAKE_OPTION_COPIED_KIND => &mut copied,
773                        FAKE_OPTION_NOT_COPIED_KIND => &mut not_copied,
774                        k => panic!("unexpected option {k}"),
775                    };
776                    assert_eq!(core::mem::replace(seen, true), false);
777                    (copied, not_copied)
778                },
779            );
780            assert_eq!(copied, true, "must be copied on all fragments {position:?}");
781            assert_eq!(
782                not_copied,
783                position == FragmentPosition::First,
784                "must only be in first fragment {position:?}"
785            );
786        }
787    }
788
789    struct ForwardingTestEnv<E>(E);
790    impl<I: FragmentationIpExt, E: FragmentationTestEnv<I>> FragmentationTestEnv<I>
791        for ForwardingTestEnv<E>
792    {
793        fn new_serializer<'a>(
794            &self,
795            body: &'a [u8],
796        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a {
797            use packet_formats::ip::IpPacket as _;
798            let Self(inner) = self;
799            let mut buffer = inner
800                .new_serializer(body)
801                .serialize_outer(packet::NoReuseBufferProvider(packet::new_buf_vec))
802                .map_err(|(err, _)| err)
803                .unwrap();
804            let packet =
805                <I::Packet<_> as ParsablePacket<_, _>>::parse(buffer.buffer_view(), ()).unwrap();
806            let src_addr = packet.src_ip();
807            let dst_addr = packet.dst_ip();
808            let proto = packet.proto();
809            let meta = packet.parse_metadata();
810            drop(packet);
811            ForwardedPacket::new(src_addr, dst_addr, proto, meta, buffer)
812        }
813        fn check_fragment(
814            &self,
815            fragment: &mut Buf<Vec<u8>>,
816            position: FragmentPosition,
817            offset: usize,
818        ) {
819            let Self(inner) = self;
820            inner.check_fragment(fragment, position, offset)
821        }
822    }
823
824    struct Ipv6TestEnv;
825
826    const IPV6_ID: u32 = 0x1234ABCD;
827
828    impl FragmentationTestEnv<Ipv6> for Ipv6TestEnv {
829        fn new_serializer<'a>(
830            &self,
831            body: &'a [u8],
832        ) -> impl FragmentableIpSerializer<Ipv6, Buffer: Buffer> + 'a {
833            body.into_serializer().encapsulate(Ipv6PacketBuilder::new(
834                TEST_ADDRS_V6.local_ip,
835                TEST_ADDRS_V6.remote_ip,
836                1,
837                IpProto::Udp.into(),
838            ))
839        }
840
841        fn check_fragment(
842            &self,
843            fragment: &mut Buf<Vec<u8>>,
844            position: FragmentPosition,
845            offset: usize,
846        ) {
847            let packet = Ipv6Packet::parse(fragment.buffer_view(), ()).unwrap();
848            assert_eq!(packet.src_ip(), TEST_ADDRS_V6.local_ip.get());
849            assert_eq!(packet.dst_ip(), TEST_ADDRS_V6.remote_ip.get());
850            assert_eq!(packet.hop_limit(), 1);
851            assert_eq!(packet.proto(), IpProto::Udp.into());
852            let fragment = packet
853                .iter_extension_hdrs()
854                .find_map(|h| match h.into_data() {
855                    Ipv6ExtensionHeaderData::Fragment { fragment_data } => Some(fragment_data),
856                    _ => None,
857                })
858                .expect("no fragment header");
859            assert_eq!(fragment.identification(), IPV6_ID);
860            assert_eq!(usize::from(fragment.fragment_offset().into_bytes()), offset);
861            assert_eq!(fragment.m_flag(), position != FragmentPosition::Last);
862        }
863    }
864
865    struct FixedIdContext;
866    impl FragmentationIdGenContext for FixedIdContext {
867        fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
868            let WrapFragmentationId(id) =
869                I::map_ip_out((), |()| WrapFragmentationId(()), |()| WrapFragmentationId(IPV6_ID));
870            id
871        }
872    }
873
874    #[test_case::test_matrix(
875        [
876            Ipv4TestEnv::default(),
877            Ipv4WithOptionsTestEnv::default(),
878            ForwardingTestEnv(Ipv4TestEnv::default()),
879            ForwardingTestEnv(Ipv4WithOptionsTestEnv::default()),
880            Ipv6TestEnv,
881        ],
882        0..=2
883    )]
884    fn fragment<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(
885        env: E,
886        middle_fragments: usize,
887    ) {
888        // NB: We're using the fact that MTU is larger than the header sizes
889        // here to end up obtaining the right number of middle fragments as
890        // expected. This makes this test sensitive to the relation between the
891        // picked MTU and the header sizes for the multiple serializers.
892        let full_body = gen_body(usize::from(TEST_MTU) * (1 + middle_fragments));
893        let mut body_view = Buf::new(&full_body[..], ..);
894        let serializer = env.new_serializer(&full_body[..]);
895        let mut fragmenter = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU)
896            .expect("create fragmenter");
897
898        let mut frag = fragmenter.next_serialized();
899        env.check_fragment(&mut frag, FragmentPosition::First, body_view.prefix_len());
900        assert_eq!(
901            frag.as_ref(),
902            body_view.buffer_view().take_front(fragmenter.max_fragment_body_first).unwrap()
903        );
904
905        for _ in 0..middle_fragments {
906            let mut frag = fragmenter.next_serialized();
907            env.check_fragment(&mut frag, FragmentPosition::Middle, body_view.prefix_len());
908            assert_eq!(
909                frag.as_ref(),
910                body_view.buffer_view().take_front(fragmenter.max_fragment_body_remaining).unwrap()
911            );
912        }
913
914        let mut frag = fragmenter.next_serialized();
915        env.check_fragment(&mut frag, FragmentPosition::Last, body_view.prefix_len());
916        assert_eq!(frag.as_ref(), body_view.buffer_view().into_rest());
917
918        // No more fragments.
919        assert!(fragmenter.next().is_none());
920    }
921
922    #[test_case(Ipv4TestEnv::dont_frag())]
923    #[test_case(Ipv4WithOptionsTestEnv(Ipv4TestEnv::dont_frag()))]
924    #[test_case(ForwardingTestEnv(Ipv4TestEnv::dont_frag()))]
925    #[test_case(ForwardingTestEnv(Ipv6TestEnv))]
926    fn not_allowed<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
927        let body = gen_body(usize::from(TEST_MTU));
928        let serializer = env.new_serializer(&body[..]);
929        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
930        assert_eq!(result, Err(FragmentationError::NotAllowed))
931    }
932
933    #[test_case(Ipv4TestEnv::default())]
934    #[test_case(Ipv4WithOptionsTestEnv::default())]
935    #[test_case(ForwardingTestEnv(Ipv4TestEnv::default()))]
936    #[test_case(Ipv6TestEnv)]
937    fn mtu_too_small<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
938        let body = gen_body(usize::from(TEST_MTU));
939        let serializer = env.new_serializer(&body[..]);
940        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, Mtu::new(10)).map(|_| ());
941        assert_eq!(result, Err(FragmentationError::MtuTooSmall));
942    }
943
944    #[test_case(Ipv4TestEnv::default())]
945    #[test_case(Ipv4WithOptionsTestEnv::default())]
946    #[test_case(Ipv6TestEnv)]
947    fn body_too_long<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
948        let body = gen_body(MAX_FRAGMENT_OFFSET + usize::from(TEST_MTU));
949        let serializer = env.new_serializer(&body[..]);
950        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
951        assert_eq!(result, Err(FragmentationError::BodyTooLong));
952    }
953}